Dataset NIHMS

Dataset is obtained from https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3955412/

Do some configuration for the outputs and load libraries and packages


In [1]:
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}
In [2]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
In [3]:
import sys 
import io
import os
import pickle
import numpy as np
import pandas as pd

from gensim.models import Word2Vec
from sklearn.manifold import TSNE
from IPython.display import SVG
from sklearn.model_selection import train_test_split

import gensim.downloader as api
#wv = api.load('word2vec-google-news-300')
In [4]:
### Matplotlib

import matplotlib.pyplot as plt 
from matplotlib import rcParams
import matplotlib.cm as cm
import matplotlib as mpl
from matplotlib.colors import ListedColormap
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

dark2_colors = ['#1b9e77','#d95f02','#7570b3','#e7298a','#66a61e','#e6ab02','#a6761d','#666666']
def set_mpl_params():
    rcParams['figure.figsize'] = (10, 6)
    rcParams['figure.dpi'] = 150
    rcParams['lines.linewidth'] = 2
    rcParams['axes.facecolor'] = 'white'
    rcParams['font.size'] = 12
    rcParams['patch.edgecolor'] = 'white'
    rcParams['patch.facecolor'] = dark2_colors[0]
    #rcParams['font.family'] = 'StixGeneral'
    rcParams['font.family'] = 'sans-serif'
    rcParams['font.sans-serif'] = ['DejaVu Sans']

mpl.rcdefaults()
set_mpl_params()

pd.set_option('display.width', 500)
pd.set_option('display.max_columns', 100)


%matplotlib inline  
In [355]:
import rdkit as rdk
from rdkit import DataStructs
from rdkit import Chem
from rdkit.Chem import QED
from rdkit.Chem import AllChem
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole

import warnings
warnings.filterwarnings('ignore')

from mol2vec.features import mol2alt_sentence, MolSentence, DfVec, sentences2vec
from mol2vec.helpers import depict_identifier, plot_2D_vectors, IdentifierTable, mol_to_svg

Calculate Morgan fingerprints and break it down into substructures (words). From by mol2vec (https://github.com/samoturk/mol2vec)


In [6]:
def mol2sentence(mol, radius):

    """Calculates ECFP (Morgan fingerprint) and returns identifiers of substructures as 'sentence' (string).
    Returns a tuple with 1) a list with sentence for each radius and 2) a sentence with identifiers from all radii
    combined.
    NOTE: Words are ALWAYS reordered according to atom order in the input mol object.
    NOTE: Due to the way how Morgan FPs are generated, number of identifiers at each radius is smaller
    
    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
    radius : float 
        Fingerprint radius
    Returns
    -------
    identifier sentence
        List with sentences for each radius
    alternating sentence
        Sentence (list) with identifiers from all radii combined
    """
    radii = list(range(int(radius) + 1))
    info = {}
    _ = AllChem.GetMorganFingerprint(mol, radius, bitInfo=info)  # info: dictionary identifier, atom_idx, radius

    mol_atoms = [a.GetIdx() for a in mol.GetAtoms()]
    dict_atoms = {x: {r: None for r in radii} for x in mol_atoms}

    for element in info:
        for atom_idx, radius_at in info[element]:
            dict_atoms[atom_idx][radius_at] = element  # {atom number: {fp radius: identifier}}

    # iterate over all atoms and radii
    identifier_sentences = []
    
    for r in radii:  # iterate over radii to get one sentence per radius
        identifiers = []
        for atom in dict_atoms:  # iterate over atoms
            # get one sentence per radius
            identifiers.append(dict_atoms[atom][r])
        identifier_sentences.append(list(map(str, [x for x in identifiers if x])))
    
    # merge identifiers alternating radius to sentence: atom 0 radius0, atom 0 radius 1, etc.
    identifiers_alt = []
    for atom in dict_atoms:  # iterate over atoms
        for r in radii:  # iterate over radii
            identifiers_alt.append(dict_atoms[atom][r])

    alternating_sentence = map(str, [x for x in identifiers_alt if x])

    return list(identifier_sentences), list(alternating_sentence)

Takes the vectors for the word and sums them up for the sentence (compound)


In [7]:
def sentences2vec(sentences, model, unseen=None):
    """Generate vectors for each sentence (list) in a list of sentences. Vector is simply a
    sum of vectors for individual words.
    
    Parameters
    ----------
    sentences : list, array
        List with sentences
    model : word2vec.Word2Vec
        Gensim word2vec model
    unseen : None, str
        Keyword for unseen words. If None, those words are skipped.
        https://stats.stackexchange.com/questions/163005/how-to-set-the-dictionary-for-text-analysis-using-neural-networks/163032#163032
    Returns
    -------
    np.array
    """
    keys = set(model.wv.vocab.keys())
    vec = []
    if unseen:
        unseen_vec = model.wv.word_vec(unseen)

    for sentence in sentences:
        if unseen:
            vec.append(sum([model.wv.word_vec(y) if y in set(sentence) & keys
                       else unseen_vec for y in sentence]))
        else:
            vec.append(sum([model.wv.word_vec(y) for y in sentence 
                            if y in set(sentence) & keys]))
    return np.array(vec)

Creates features of the compound in 300 dimensions


In [8]:
def featurize(df, out_file, model_path, r, uncommon=None):
    """Featurize mols in a Pandas dataframe.
    SMILES are regenerated with RDKit to get canonical SMILES without chirality information.
    Parameters
    ----------
    df : dataframe
        Input Panda dataframe
    out_file : str
        Output csv
    model_path : str
        File path to pre-trained Gensim word2vec model
    r : int
        Radius of morgan fingerprint
    uncommon : str
        String to used to replace uncommon words/identifiers while training. Vector obtained for 'uncommon' will be used
        to encode new (unseen) identifiers
    Returns
    -------
    """
    # Load the model
    word2vec_model = Word2Vec.load(model_path)
    if uncommon:
        try:
            word2vec_model[uncommon]
        except KeyError:
            raise KeyError('Selected word for uncommon: %s not in vocabulary' % uncommon)

    print('Loading molecules.')
    df['ROMol'] = df.apply(lambda x: Chem.MolFromSmiles(str(x['Smiles'])), axis=1)
    print("Keeping only molecules that can be processed by RDKit.")
    df = df[df['ROMol'].notnull()]
    df['QED'] = df.apply(lambda x: QED.qed(x['ROMol']), axis=1)
    df['Smiles'] = df['ROMol'].map(Chem.MolToSmiles)  # Recreate SMILES

    print('Featurizing molecules.')
    df['mol-sentence'] = df.apply(lambda x: MolSentence(mol2sentence(x['ROMol'], r)[1]), axis=1)
    
    vectors = sentences2vec(df['mol-sentence'], word2vec_model, unseen=uncommon)
    print(vectors.shape)
    df_vec = pd.DataFrame(vectors, columns=['mol2vec-%03i' % x for x in range(vectors.shape[1])])
    df_vec.index = df.index
    df = df.join(df_vec)

    df.drop(['ROMol', 'mol-sentence'], axis=1).to_csv(out_file)
    return vectors, df

Positive dataset (Bioavailable compounds)


Loading, getting statistics, checking the dataframe

In [9]:
in_file = './BA-NIHMS.csv'
df_pos = pd.read_csv(in_file, delimiter=',', usecols=[0, 1, 2, 3], names=['ID', 'Name', 'Smiles', 'PercentF'], header=0, encoding='latin-1')  # Assume <tab> separated
df_pos.describe()
Out[9]:
PercentF
count 509.000000
mean 80.603340
std 14.820247
min 51.000000
25% 68.000000
50% 84.000000
75% 93.000000
max 99.000000
In [10]:
df_pos
Out[10]:
ID Name Smiles PercentF
0 NIHMSP001 3-Ketodesogestrel OC1(CCC2C3C(C4C(=CC(=O)CC4)CC3)C(CC12CC)=C)C#C 76.0
1 NIHMSP002 Abacavir OCC1CC(n2c3nc(nc(NC4CC4)c3nc2)N)C=C1 83.0
2 NIHMSP003 Abecarnil O(Cc1ccccc1)C=1C=CC2=NC=3C(=C2C=1)C(COC)=C(NC=... 92.0
3 NIHMSP004 Acenocoumarol O1c2c(cccc2)C(O)=C(C(CC(=O)C)c2ccc([N+](=O)[O-... 60.0
4 NIHMSP005 Acepromazine S1c2c(N(c3c1cccc3)CCCN(C)C)cc(cc2)C(=O)C 55.0
... ... ... ... ...
504 NIHMSP505 Zalcitabine O1C(CCC1N1C=CC(=NC1=O)N)CO 88.0
505 NIHMSP506 Zidovudine CC1=CN(C(=O)NC1=O)C2CC(C(O2)CO)N=[N+]=[N-] 63.0
506 NIHMSP507 Ziprasidone Clc1cc2NC(=O)Cc2cc1CCN1CCN(CC1)c1nsc2c1cccc2 60.0
507 NIHMSP508 Zolpidem O=C(N(C)C)Cc1n2C=C(C=Cc2nc1-c1ccc(cc1)C)C 72.0
508 NIHMSP509 Zonisamide S(=O)(=O)(N)Cc1noc2c1cccc2 99.0

509 rows × 4 columns

Negative dataset (NON- Bioavailable compounds)


Loading, getting statistics

In [295]:
in_file = './non-BA-NIHMS.csv'
df_neg = pd.read_csv(in_file, delimiter=',', usecols=[0, 1, 2, 3], names=['ID', 'Name', 'Smiles', 'PercentF'], header = 0, encoding='latin-1')  # Assume <tab> separated
df_neg.describe()
Out[295]:
PercentF
count 486.000000
mean 20.410267
std 16.608619
min 0.000000
25% 4.625000
50% 19.750000
75% 34.750000
max 50.000000

Assign classes, '0' to a negative dataset and '1' to a positive dataset. Join datasets within one dataframe


Loading, getting statistics

In [296]:
df_neg['Class'] = 0
print(df_neg.columns)

df_pos['Class'] = 1
print(df_pos.columns)
Index(['ID', 'Name', 'Smiles', 'PercentF', 'Class'], dtype='object')
Index(['ID', 'Name', 'Smiles', 'PercentF', 'Class'], dtype='object')
In [297]:
df = pd.concat([df_neg, df_pos])
df.reset_index(drop=True, inplace=True)
df.describe()
Out[297]:
PercentF Class
count 995.000000 995.000000
mean 51.202503 0.511558
std 33.956910 0.500118
min 0.000000 0.000000
25% 20.000000 0.000000
50% 53.000000 1.000000
75% 85.000000 1.000000
max 99.000000 1.000000

Load mol2vec vocabulary (model_300dim.pkl was downloaded from (https://github.com/samoturk/mol2vec). Create features and add them to the data frame


Loading, creating features, adding the features to the dataframe, getting statistics

In [298]:
model_path = './models/model_300dim.pkl'
out_file = 'BA-vectors.csv'

X, df = featurize(df, out_file, model_path, 2, uncommon='UNK')
Loading molecules.
Keeping only molecules that can be processed by RDKit.
Featurizing molecules.
(995, 300)
In [15]:
df['QED'].describe()
Out[15]:
count    995.000000
mean       0.593043
std        0.218834
min        0.011297
25%        0.447129
50%        0.641049
75%        0.773144
max        0.939838
Name: QED, dtype: float64

Load libraries for visualization, set color palette, scaler, number of components


In [16]:
import seaborn as sns
sns.set(rc={'figure.figsize':(16,16)})
palette = sns.color_palette("bright",2)
In [17]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.manifold import TSNE
import umap.umap_ as UMAP

#pca_model = PCA(n_components=30)
scaler = StandardScaler()
pca_model = PCA(n_components=2)
umap_model = UMAP.UMAP()

NIHMS Bio Oral Availability


In [299]:
y =  df['Class']
predY = df['PercentF'].values / 100.
In [20]:
# calculate the variance explained by the PC analysis
pca = PCA(n_components=4).fit(scaler.fit_transform(X))
var_exp = pca.explained_variance_ratio_.cumsum()*100.
print ('The 1st Principal Component explains {:03.1f} % of the variance\n'.format(var_exp[0]))
print ('The 1st and 2nd Principal Components explain {:03.1f} % of the variance\n'.format(var_exp[1]))
print ('The 1st, 2nd and 3rd Principal Components explain {:03.1f} % of the variance\n'.format(var_exp[2]))
print ('The first four Principal Components explain {:03.1f} % of the variance\n'.format(var_exp[3]))
The 1st Principal Component explains 45.0 % of the variance

The 1st and 2nd Principal Components explain 58.4 % of the variance

The 1st, 2nd and 3rd Principal Components explain 65.5 % of the variance

The first four Principal Components explain 69.4 % of the variance

In [21]:
tsne_model = TSNE(n_components=2, random_state=20, perplexity=200, n_iter=1000, metric='cosine')
tsne = tsne_model.fit_transform(X)

pca = pca_model.fit_transform(scaler.fit_transform(X))
umap = umap_model.fit_transform(scaler.fit_transform(X))
In [22]:
df_vec = pd.DataFrame()
df_vec['identifier'] = list([str(x) for x in df['ID'].values.tolist()])
df_vec.index = df_vec['identifier']
df_vec['t-SNE-c1'] = tsne.T[0]
df_vec['t-SNE-c2'] = tsne.T[1]
df_vec['PCA-c1'] = pca.T[0]
df_vec['PCA-c2'] = pca.T[1]
df_vec['UMAP-c1'] = umap.T[0]
df_vec['UMAP-c2'] = umap.T[1]
df_vec['Class'] = ['BioAvailable'  if x == 1 
                       else 'Not BioAvailable' for x in df['Class'].tolist()]
In [23]:
df_vec.head(3)
Out[23]:
identifier t-SNE-c1 t-SNE-c2 PCA-c1 PCA-c2 UMAP-c1 UMAP-c2 Class
identifier
NIHMSN001 NIHMSN001 2.562840 -9.223626 -6.342829 -7.040108 12.599118 7.784539 Not BioAvailable
NIHMSN002 NIHMSN002 -6.471902 -4.655626 -14.465902 -2.465346 10.640924 10.109997 Not BioAvailable
NIHMSN003 NIHMSN003 4.250673 -9.489751 24.679663 -27.442698 6.226831 -0.009980 Not BioAvailable
In [24]:
sns.scatterplot('t-SNE-c1','t-SNE-c2', hue='Class', data=df_vec, legend='full', palette=palette)
plt.savefig('NIHMS_tSNE_byBioAvalClass.png')
In [25]:
sns.scatterplot('PCA-c1','PCA-c2', hue='Class', data=df_vec, legend='full', palette=palette)
plt.savefig('NIHMS_PCA_byBioAvalClass.png')
In [26]:
sns.scatterplot('UMAP-c1','UMAP-c2', hue='Class', data=df_vec, legend='full', palette=palette)
plt.savefig('NIHMS_UMAP_byBioAvalClass.png')

Classification


In [300]:
from optparse import OptionParser
from time import time

import sklearn
from sklearn import svm, datasets
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn import model_selection, metrics
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score, roc_auc_score
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, accuracy_score
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, StratifiedKFold, StratifiedShuffleSplit
from sklearn.model_selection import train_test_split, cross_val_score, cross_validate
from sklearn.preprocessing import StandardScaler, label_binarize, OneHotEncoder
from sklearn.neighbors import KNeighborsClassifier

from sklearn import linear_model, svm
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.svm import SVC, LinearSVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier, RandomForestClassifier, BaggingClassifier, RandomTreesEmbedding
from sklearn.naive_bayes import BernoulliNB, GaussianNB, MultinomialNB
from sklearn import datasets, feature_selection, cluster, feature_extraction
from sklearn import neighbors, decomposition, metrics
from sklearn import decomposition, feature_selection
from sklearn.feature_extraction.text import TfidfVectorizer, HashingVectorizer
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.linear_model import LogisticRegression, RidgeClassifier, SGDClassifier, Perceptron, PassiveAggressiveClassifier
from sklearn.svm import LinearSVC
from sklearn.neighbors import KNeighborsClassifier, NearestCentroid
from sklearn.utils.extmath import density

import json # 
import datetime as dt # module for manipulating dates and times
from datetime import datetime
import itertools
from collections import Counter
In [301]:
from sklearn.model_selection import train_test_split, cross_val_score, cross_validate
from sklearn.utils import shuffle

RANDOM_STATE = 458
TEST_SIZE = 0.2

X, y, predY = shuffle(X, y, predY, random_state=RANDOM_STATE)


X_train, X_test, y_train, y_test, predY_train, predY_test = train_test_split(X,y,predY, test_size=TEST_SIZE, random_state=RANDOM_STATE)
X_dev, X_test, y_dev, y_test, predY_dev, predY_test = train_test_split(X_test, y_test, predY_test, test_size=0.5, random_state=RANDOM_STATE)
In [302]:
print (len(X_train))
print (len(X_dev))
print (len(X_test))
796
99
100
In [96]:
def log_ratio(x, eps): 
    return np.log(x+eps/(1-x+eps))

def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

scores = [0., 1.0]
print(softmax(scores))
[0.26894142 0.73105858]

Create a Benchmark function to calculate a set of metrics for each classifier - each evaluated one at a time

In [160]:
## Benchmark the classifiers, one at a time

def log_ratio(x, eps): 
    return np.log(x+eps/(1-x+eps))

def benchmark(clf, name, X_train, X_dev, y_train, y_dev, predY_dev):
    """
    clf - the classifier
    name - its name
    
    benchmark: to create the benchmark metrics for the classification
    returns: the inputs to the results list
    """
    eps = 1e-6
    print('_' * 80)
    print("Training: ")
    print(clf)
    t0 = time()
    clf.fit(X_train, y_train)
    train_time = time() - t0
    print("train time: %0.3fs" % train_time)

    t0 = time()
    #if ('GBT Classifier' in name) : 
    #    pred = clf.predict(X_test.toarray())
    #else: 
    pred = clf.predict(X_dev)
    test_time = time() - t0
    print("test time:  %0.3fs" % test_time)

    acc_score = metrics.accuracy_score(y_dev, pred)
    print("accuracy:   %0.3f" % acc_score)
    
    f1_macro_score = metrics.f1_score(y_dev, pred, average='macro')
    print("*** F1 macro avg score:   %0.3f" % f1_macro_score)
    
    y_score = clf.predict_proba(X_dev)[:,1]
    print("predict proba function")

    auc_score = metrics.roc_auc_score(y_dev, y_score)
    print("*** AUC for ROC = %0.3f\n" % auc_score)
    
    print("classification report:")
    print(metrics.classification_report(y_dev, pred,
                                        target_names=categories))
    conf_mat = metrics.confusion_matrix(y_dev, pred)
    print("confusion matrix:")
    print(conf_mat)
    tn, fp, fn, tp = conf_mat.ravel()
    sensitivity = tp / (tp+fn) *100.
    print("\nsensitivity / recall (TPR):")
    print(np.round(sensitivity,3))
    specificity = tn / (tn+fp) *100.
    print("specificity (TNR):")
    print(np.round(specificity,3))
    CCR = ( sensitivity + specificity) / 2 
    print("Correct classification rate (CCR); balanced accuracy:")
    print(np.round(CCR,3))
    
    ybar = np.sum(predY_dev)/len(predY_dev)
    ssreg = np.sum((y_score-ybar)**2)
    sserr = np.sum((predY_dev - y_score)**2)
    sstot = np.sum((predY_dev - ybar)**2)
    R2 = 1. -  sserr / sstot
    R2 = metrics.r2_score(predY_dev,y_score)
    print("R-squared value %F:")
    #print(y_score)
    print(np.round(R2,3))
    
    mae = metrics.mean_absolute_error(predY_dev,y_score)
    print("MAE value %F:")
    print(np.round(mae,3))
    
    R2 = metrics.r2_score(log_ratio(predY_dev, eps), log_ratio(y_score, eps))
    print("R-squared value logK(%F):")
    #print(y_score)
    print(np.round(R2,3))
    
    mae = metrics.mean_absolute_error(log_ratio(predY_dev, eps), log_ratio(y_score, eps))
    print("MAE value logK(%F):")
    print(np.round(mae,3))

    print()
    clf_descr = str(clf).split('(')[0]
    return name, acc_score, f1_macro_score, auc_score, train_time, test_time
In [161]:
def create_results(X_train, X_dev, y_train, y_dev, predY_dev):
    """
    create_results: to run the classification and create the of results 
    from the battery of classifiers
    returns: an multiD list of results
    """
    results = []
    for clf, name in (
        (KNeighborsClassifier(n_neighbors=20), "k Nearest Neighbors | k=20"),
        (KNeighborsClassifier(n_neighbors=30), "k Nearest Neighbors | k=30"),
        (BernoulliNB(alpha=.01),"Bernouilli Naive Bayes"),
        (DecisionTreeClassifier(max_depth=3, random_state=RANDOM_STATE), "Decision Tree | MaxDepth=3"),
        (DecisionTreeClassifier(max_depth=7, random_state=RANDOM_STATE), "Decision Tree | MaxDepth=7"),
        (BaggingClassifier(n_estimators=10, random_state=RANDOM_STATE), "Bagging | 10 trees"),
        (BaggingClassifier(n_estimators=40, max_samples=1.0, max_features=0.4, random_state=RANDOM_STATE), "Bagging | 40 trees 40% max features"),
        (AdaBoostClassifier(n_estimators=10, random_state=RANDOM_STATE), "AdaBoost | 10 trees"),
        (AdaBoostClassifier(n_estimators=40, random_state=RANDOM_STATE), "AdaBoost | 40 trees"),
        (GradientBoostingClassifier(n_estimators=40, learning_rate=1.0, max_depth=1, random_state=RANDOM_STATE) , "Gradient Boosting | 40 trees"),
        (GradientBoostingClassifier(n_estimators=40, learning_rate=0.7, max_depth=10, random_state=RANDOM_STATE), "Gradient Boosting | 40 trees, max-depth=10"),
        (RandomForestClassifier(n_estimators=100, max_features='auto', random_state=RANDOM_STATE), "Random Forest | 100 trees"),
        (RandomForestClassifier(n_estimators=500, max_features='auto', random_state=RANDOM_STATE), "Random forest | 500 trees"), 
    ):
        print('=' * 80)
        print(name)
        results.append(benchmark(clf, name, X_train, X_dev, y_train, y_dev, predY_dev))
        
    '''
    for penalty in ["L1", "L2"]:
        print('=' * 80)
        print("%s penalty" % penalty.upper())

        # Train SGD model
        results.append(benchmark(SGDClassifier(loss="hinge", alpha=1e-4, max_iter=1000,
                                               penalty=penalty, random_state=RANDOM_STATE), "SGD classifier | hinge loss, %s penalty, alpha=1e-4" %penalty,X_train, X_dev, y_train, y_dev, predY_dev))
        results.append(benchmark(SGDClassifier(loss="hinge", alpha=1e-6, max_iter=1000,
                                       penalty=penalty, random_state=RANDOM_STATE), "SGD classifier | hinge loss %s penalty, alpha=1e-6" %penalty,X_train, X_dev, y_train, y_dev, predY_dev))
        results.append(benchmark(SGDClassifier(loss="log", alpha=.0001, max_iter=1000,
                                   penalty=penalty, random_state=RANDOM_STATE), "SGD classifier | log loss %s penalty" %penalty,X_train, X_dev, y_train, y_dev, predY_dev))
    # Train SGD with Elastic Net penalty
    print('=' * 80)
    print("Elastic-Net penalty")
    results.append(benchmark(SGDClassifier(alpha=.0001, max_iter=10000,
                                           penalty="elasticnet", random_state=RANDOM_STATE),"SGD classifier | Elastic-Net penalty",X_train, X_dev, y_train, y_dev, predY_dev))
    '''
    
    return results
In [162]:
def comparison_plots(results):
    """
    results - array containing the results from the classification to plot
    
    yields: prints out the results from each classifier and then finishes with plots of the 
    accuracy scores and ROC AUC scores for all the classifiers
    """
    ######
    # Plotting logistics
    indices = np.arange(len(results))
    results = [[x[i] for x in results] for i in range(len(results[0]))]
    clf_names, acc_score, f1_macro, auc_score, training_time, test_time = results
 
    data_tuples = list(zip(clf_names,acc_score, f1_macro, auc_score ))
    dataframe_to_plot = pd.DataFrame(data_tuples, columns=['Classifier','Accuracy Score', 'F1 Macro-Avg Score', 'AUC of ROC curve'])

    sns.set(style="whitegrid")
    qualitative_colors = sns.color_palette("Set2", 10)
    sns.set_palette(qualitative_colors)
    sns.set(rc={'figure.figsize':(12,16)})

    ax = sns.barplot(x='Accuracy Score', y='Classifier', data=dataframe_to_plot, palette = qualitative_colors)
    ax.set_yticklabels('')
    for i, pos in enumerate(indices):
        ax.annotate(clf_names[i], (0.02, pos+0.2))
    for i, pos in enumerate(indices):
        ax.annotate(str(round(acc_score[i]*100,2))+'%', (acc_score[i]-0.07, pos+0.2)) 
    plt.show()
    
    ax = sns.barplot(x='F1 Macro-Avg Score', y='Classifier', data=dataframe_to_plot, palette = qualitative_colors)
    ax.set_yticklabels('')
    for i, pos in enumerate(indices):
        ax.annotate(clf_names[i], (0.02, pos+0.2))
    for i, pos in enumerate(indices):
        ax.annotate(str(round(f1_macro[i]*100,2))+'%', (f1_macro[i]-0.07, pos+0.2))
    plt.show()
    
    ax = sns.barplot(x='AUC of ROC curve', y='Classifier', data=dataframe_to_plot, palette = qualitative_colors)
    ax.set_yticklabels('')
    for i, pos in enumerate(indices):
        ax.annotate(clf_names[i], (0.02, pos+0.2))
    for i, pos in enumerate(indices):
        ax.annotate(str(round(auc_score[i]*100,2))+'%', (auc_score[i]-0.07, pos+0.2))
    plt.show()
    

Run the classifiers

In [163]:
# options
print_top10 = True
n_features = 2 ** 16
filtered = True
target_names = ['pos', 'neg']
categories = ['Postive', 'Negative']
In [164]:
results1 = create_results(X_train, X_dev, y_train, y_dev, predY_dev)
================================================================================
k Nearest Neighbors | k=20
________________________________________________________________________________
Training: 
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                     metric_params=None, n_jobs=None, n_neighbors=20, p=2,
                     weights='uniform')
train time: 0.021s
test time:  0.035s
accuracy:   0.697
*** F1 macro avg score:   0.694
predict proba function
*** AUC for ROC = 0.760

classification report:
              precision    recall  f1-score   support

     Postive       0.67      0.67      0.67        45
    Negative       0.72      0.72      0.72        54

    accuracy                           0.70        99
   macro avg       0.69      0.69      0.69        99
weighted avg       0.70      0.70      0.70        99

confusion matrix:
[[30 15]
 [15 39]]

sensitivity / recall (TPR):
72.222
specificity (TNR):
66.667
Correct classification rate (CCR); balanced accuracy:
69.444
R-squared value %F:
0.192
MAE value %F:
0.258
R-squared value logK(%F):
0.013
MAE value logK(%F):
1.173

================================================================================
k Nearest Neighbors | k=30
________________________________________________________________________________
Training: 
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                     metric_params=None, n_jobs=None, n_neighbors=30, p=2,
                     weights='uniform')
train time: 0.008s
test time:  0.031s
accuracy:   0.677
*** F1 macro avg score:   0.673
predict proba function
*** AUC for ROC = 0.759

classification report:
              precision    recall  f1-score   support

     Postive       0.65      0.62      0.64        45
    Negative       0.70      0.72      0.71        54

    accuracy                           0.68        99
   macro avg       0.67      0.67      0.67        99
weighted avg       0.68      0.68      0.68        99

confusion matrix:
[[28 17]
 [15 39]]

sensitivity / recall (TPR):
72.222
specificity (TNR):
62.222
Correct classification rate (CCR); balanced accuracy:
67.222
R-squared value %F:
0.223
MAE value %F:
0.255
R-squared value logK(%F):
-0.01
MAE value logK(%F):
1.176

================================================================================
Bernouilli Naive Bayes
________________________________________________________________________________
Training: 
BernoulliNB(alpha=0.01, binarize=0.0, class_prior=None, fit_prior=True)
train time: 0.003s
test time:  0.001s
accuracy:   0.646
*** F1 macro avg score:   0.644
predict proba function
*** AUC for ROC = 0.712

classification report:
              precision    recall  f1-score   support

     Postive       0.61      0.62      0.62        45
    Negative       0.68      0.67      0.67        54

    accuracy                           0.65        99
   macro avg       0.64      0.64      0.64        99
weighted avg       0.65      0.65      0.65        99

confusion matrix:
[[28 17]
 [18 36]]

sensitivity / recall (TPR):
66.667
specificity (TNR):
62.222
Correct classification rate (CCR); balanced accuracy:
64.444
R-squared value %F:
-1.079
MAE value %F:
0.382
R-squared value logK(%F):
-2.228
MAE value logK(%F):
3.18

================================================================================
Decision Tree | MaxDepth=3
________________________________________________________________________________
Training: 
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                       max_depth=3, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=458, splitter='best')
train time: 0.055s
test time:  0.000s
accuracy:   0.737
*** F1 macro avg score:   0.729
predict proba function
*** AUC for ROC = 0.718

classification report:
              precision    recall  f1-score   support

     Postive       0.76      0.62      0.68        45
    Negative       0.73      0.83      0.78        54

    accuracy                           0.74        99
   macro avg       0.74      0.73      0.73        99
weighted avg       0.74      0.74      0.73        99

confusion matrix:
[[28 17]
 [ 9 45]]

sensitivity / recall (TPR):
83.333
specificity (TNR):
62.222
Correct classification rate (CCR); balanced accuracy:
72.778
R-squared value %F:
0.082
MAE value %F:
0.251
R-squared value logK(%F):
0.002
MAE value logK(%F):
1.19

================================================================================
Decision Tree | MaxDepth=7
________________________________________________________________________________
Training: 
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                       max_depth=7, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=458, splitter='best')
train time: 0.099s
test time:  0.000s
accuracy:   0.677
*** F1 macro avg score:   0.675
predict proba function
*** AUC for ROC = 0.680

classification report:
              precision    recall  f1-score   support

     Postive       0.64      0.67      0.65        45
    Negative       0.71      0.69      0.70        54

    accuracy                           0.68        99
   macro avg       0.67      0.68      0.68        99
weighted avg       0.68      0.68      0.68        99

confusion matrix:
[[30 15]
 [17 37]]

sensitivity / recall (TPR):
68.519
specificity (TNR):
66.667
Correct classification rate (CCR); balanced accuracy:
67.593
R-squared value %F:
-0.553
MAE value %F:
0.313
R-squared value logK(%F):
-2.407
MAE value logK(%F):
2.771

================================================================================
Bagging | 10 trees
________________________________________________________________________________
Training: 
BaggingClassifier(base_estimator=None, bootstrap=True, bootstrap_features=False,
                  max_features=1.0, max_samples=1.0, n_estimators=10,
                  n_jobs=None, oob_score=False, random_state=458, verbose=0,
                  warm_start=False)
train time: 0.726s
test time:  0.001s
accuracy:   0.717
*** F1 macro avg score:   0.717
predict proba function
*** AUC for ROC = 0.754

classification report:
              precision    recall  f1-score   support

     Postive       0.67      0.76      0.71        45
    Negative       0.77      0.69      0.73        54

    accuracy                           0.72        99
   macro avg       0.72      0.72      0.72        99
weighted avg       0.72      0.72      0.72        99

confusion matrix:
[[34 11]
 [17 37]]

sensitivity / recall (TPR):
68.519
specificity (TNR):
75.556
Correct classification rate (CCR); balanced accuracy:
72.037
R-squared value %F:
0.124
MAE value %F:
0.255
R-squared value logK(%F):
-0.838
MAE value logK(%F):
1.699

================================================================================
Bagging | 40 trees 40% max features
________________________________________________________________________________
Training: 
BaggingClassifier(base_estimator=None, bootstrap=True, bootstrap_features=False,
                  max_features=0.4, max_samples=1.0, n_estimators=40,
                  n_jobs=None, oob_score=False, random_state=458, verbose=0,
                  warm_start=False)
train time: 1.127s
test time:  0.004s
accuracy:   0.808
*** F1 macro avg score:   0.806
predict proba function
*** AUC for ROC = 0.833

classification report:
              precision    recall  f1-score   support

     Postive       0.80      0.78      0.79        45
    Negative       0.82      0.83      0.83        54

    accuracy                           0.81        99
   macro avg       0.81      0.81      0.81        99
weighted avg       0.81      0.81      0.81        99

confusion matrix:
[[35 10]
 [ 9 45]]

sensitivity / recall (TPR):
83.333
specificity (TNR):
77.778
Correct classification rate (CCR); balanced accuracy:
80.556
R-squared value %F:
0.278
MAE value %F:
0.224
R-squared value logK(%F):
-0.026
MAE value logK(%F):
1.179

================================================================================
AdaBoost | 10 trees
________________________________________________________________________________
Training: 
AdaBoostClassifier(algorithm='SAMME.R', base_estimator=None, learning_rate=1.0,
                   n_estimators=10, random_state=458)
train time: 0.164s
test time:  0.001s
accuracy:   0.697
*** F1 macro avg score:   0.696
predict proba function
*** AUC for ROC = 0.782

classification report:
              precision    recall  f1-score   support

     Postive       0.65      0.71      0.68        45
    Negative       0.74      0.69      0.71        54

    accuracy                           0.70        99
   macro avg       0.70      0.70      0.70        99
weighted avg       0.70      0.70      0.70        99

confusion matrix:
[[32 13]
 [17 37]]

sensitivity / recall (TPR):
68.519
specificity (TNR):
71.111
Correct classification rate (CCR); balanced accuracy:
69.815
R-squared value %F:
0.077
MAE value %F:
0.288
R-squared value logK(%F):
-0.045
MAE value logK(%F):
1.239

================================================================================
AdaBoost | 40 trees
________________________________________________________________________________
Training: 
AdaBoostClassifier(algorithm='SAMME.R', base_estimator=None, learning_rate=1.0,
                   n_estimators=40, random_state=458)
train time: 0.655s
test time:  0.005s
accuracy:   0.646
*** F1 macro avg score:   0.643
predict proba function
*** AUC for ROC = 0.768

classification report:
              precision    recall  f1-score   support

     Postive       0.61      0.60      0.61        45
    Negative       0.67      0.69      0.68        54

    accuracy                           0.65        99
   macro avg       0.64      0.64      0.64        99
weighted avg       0.65      0.65      0.65        99

confusion matrix:
[[27 18]
 [17 37]]

sensitivity / recall (TPR):
68.519
specificity (TNR):
60.0
Correct classification rate (CCR); balanced accuracy:
64.259
R-squared value %F:
0.016
MAE value %F:
0.297
R-squared value logK(%F):
-0.063
MAE value logK(%F):
1.259

================================================================================
Gradient Boosting | 40 trees
________________________________________________________________________________
Training: 
GradientBoostingClassifier(ccp_alpha=0.0, criterion='friedman_mse', init=None,
                           learning_rate=1.0, loss='deviance', max_depth=1,
                           max_features=None, max_leaf_nodes=None,
                           min_impurity_decrease=0.0, min_impurity_split=None,
                           min_samples_leaf=1, min_samples_split=2,
                           min_weight_fraction_leaf=0.0, n_estimators=40,
                           n_iter_no_change=None, presort='deprecated',
                           random_state=458, subsample=1.0, tol=0.0001,
                           validation_fraction=0.1, verbose=0,
                           warm_start=False)
train time: 0.604s
test time:  0.001s
accuracy:   0.697
*** F1 macro avg score:   0.695
predict proba function
*** AUC for ROC = 0.749

classification report:
              precision    recall  f1-score   support

     Postive       0.66      0.69      0.67        45
    Negative       0.73      0.70      0.72        54

    accuracy                           0.70        99
   macro avg       0.70      0.70      0.70        99
weighted avg       0.70      0.70      0.70        99

confusion matrix:
[[31 14]
 [16 38]]

sensitivity / recall (TPR):
70.37
specificity (TNR):
68.889
Correct classification rate (CCR); balanced accuracy:
69.63
R-squared value %F:
-0.072
MAE value %F:
0.26
R-squared value logK(%F):
0.112
MAE value logK(%F):
1.183

================================================================================
Gradient Boosting | 40 trees, max-depth=10
________________________________________________________________________________
Training: 
GradientBoostingClassifier(ccp_alpha=0.0, criterion='friedman_mse', init=None,
                           learning_rate=0.7, loss='deviance', max_depth=10,
                           max_features=None, max_leaf_nodes=None,
                           min_impurity_decrease=0.0, min_impurity_split=None,
                           min_samples_leaf=1, min_samples_split=2,
                           min_weight_fraction_leaf=0.0, n_estimators=40,
                           n_iter_no_change=None, presort='deprecated',
                           random_state=458, subsample=1.0, tol=0.0001,
                           validation_fraction=0.1, verbose=0,
                           warm_start=False)
train time: 1.860s
test time:  0.000s
accuracy:   0.677
*** F1 macro avg score:   0.675
predict proba function
*** AUC for ROC = 0.769

classification report:
              precision    recall  f1-score   support

     Postive       0.64      0.67      0.65        45
    Negative       0.71      0.69      0.70        54

    accuracy                           0.68        99
   macro avg       0.67      0.68      0.68        99
weighted avg       0.68      0.68      0.68        99

confusion matrix:
[[30 15]
 [17 37]]

sensitivity / recall (TPR):
68.519
specificity (TNR):
66.667
Correct classification rate (CCR); balanced accuracy:
67.593
R-squared value %F:
-0.379
MAE value %F:
0.287
R-squared value logK(%F):
-0.254
MAE value logK(%F):
1.828

================================================================================
Random Forest | 100 trees
________________________________________________________________________________
Training: 
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=100,
                       n_jobs=None, oob_score=False, random_state=458,
                       verbose=0, warm_start=False)
train time: 0.561s
test time:  0.008s
accuracy:   0.788
*** F1 macro avg score:   0.786
predict proba function
*** AUC for ROC = 0.822

classification report:
              precision    recall  f1-score   support

     Postive       0.76      0.78      0.77        45
    Negative       0.81      0.80      0.80        54

    accuracy                           0.79        99
   macro avg       0.79      0.79      0.79        99
weighted avg       0.79      0.79      0.79        99

confusion matrix:
[[35 10]
 [11 43]]

sensitivity / recall (TPR):
79.63
specificity (TNR):
77.778
Correct classification rate (CCR); balanced accuracy:
78.704
R-squared value %F:
0.293
MAE value %F:
0.231
R-squared value logK(%F):
0.063
MAE value logK(%F):
1.102

================================================================================
Random forest | 500 trees
________________________________________________________________________________
Training: 
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=500,
                       n_jobs=None, oob_score=False, random_state=458,
                       verbose=0, warm_start=False)
train time: 2.656s
test time:  0.033s
accuracy:   0.778
*** F1 macro avg score:   0.775
predict proba function
*** AUC for ROC = 0.833

classification report:
              precision    recall  f1-score   support

     Postive       0.77      0.73      0.75        45
    Negative       0.79      0.81      0.80        54

    accuracy                           0.78        99
   macro avg       0.78      0.77      0.77        99
weighted avg       0.78      0.78      0.78        99

confusion matrix:
[[33 12]
 [10 44]]

sensitivity / recall (TPR):
81.481
specificity (TNR):
73.333
Correct classification rate (CCR); balanced accuracy:
77.407
R-squared value %F:
0.296
MAE value %F:
0.231
R-squared value logK(%F):
0.05
MAE value logK(%F):
1.113

In [165]:
comparison_plots(results1)
In [166]:
def compare_trees(X_train, X_test, y_train, y_test):
    """"
    X_train - training set data features
    X_test - validation set data features
    y_train - training set data labels
    y_test - validation set truth
    
    compare_trees: function to run combo tree-based classifiers with Logistic Regresion
    and to plot comparable ROC curves for them
    yields: plots of the ROC curves and AUC scores
    """
    n_estimator = 40
    X_train, X_train_lr, y_train, y_train_lr = train_test_split(X_train, y_train, test_size=0.5)

    # Unsupervised transformation based on totally random trees
    rt = RandomTreesEmbedding(n_estimators=40, max_depth=1, random_state=RANDOM_STATE)

    rt_lm = LogisticRegression()
    pipeline = make_pipeline(rt, rt_lm)
    pipeline.fit(X_train, y_train)
    y_pred_rt = pipeline.predict_proba(X_test)[:, 1]
    fpr_rt_lm, tpr_rt_lm, _ = roc_curve(y_test, y_pred_rt)

    # Supervised transformation based on random forests
    rf = RandomForestClassifier(max_depth=3, n_estimators=200, random_state=RANDOM_STATE)
    rf_enc = OneHotEncoder()
    rf_lm = LogisticRegression()
    rf.fit(X_train, y_train)
    rf_enc.fit(rf.apply(X_train))
    rf_lm.fit(rf_enc.transform(rf.apply(X_train_lr)), y_train_lr)

    y_pred_rf_lm = rf_lm.predict_proba(rf_enc.transform(rf.apply(X_test)))[:, 1]
    fpr_rf_lm, tpr_rf_lm, _ = roc_curve(y_test, y_pred_rf_lm)

    grd = GradientBoostingClassifier(n_estimators=40, learning_rate=1.0, max_depth=1, random_state=RANDOM_STATE)
    grd_enc = OneHotEncoder()
    grd_lm = LogisticRegression()
    grd.fit(X_train, y_train)
    grd_enc.fit(grd.apply(X_train)[:, :, 0])
    grd_lm.fit(grd_enc.transform(grd.apply(X_train_lr)[:, :, 0]), y_train_lr)

    y_pred_grd_lm = grd_lm.predict_proba(
        grd_enc.transform(grd.apply(X_test)[:, :, 0]))[:, 1]
    fpr_grd_lm, tpr_grd_lm, _ = roc_curve(y_test, y_pred_grd_lm)

    # The gradient boosted model by itself
    y_pred_grd = grd.predict_proba(np.asarray(X_test))[:, 1]
    fpr_grd, tpr_grd, _ = roc_curve(y_test, y_pred_grd)

    # The random forest model by itself
    y_pred_rf = rf.predict_proba(X_test)[:, 1]
    fpr_rf, tpr_rf, _ = roc_curve(y_test, y_pred_rf)

    plt.figure(1)
    plt.figure(figsize=(10,8))
    plt.plot([0, 1], [0, 1], 'k--')
    plt.plot(fpr_rt_lm, tpr_rt_lm, label='RT + LR (area = {0:0.3f})'
                       ''.format(auc(fpr_rt_lm, tpr_rt_lm)))
    plt.plot(fpr_rf, tpr_rf, label='RF (area = {0:0.3f})'
                       ''.format(auc(fpr_rf, tpr_rf)))
    plt.plot(fpr_rf_lm, tpr_rf_lm, label='RF + LR (area = {0:0.3f})'
                       ''.format(auc(fpr_rf_lm, tpr_rf_lm)))
    plt.plot(fpr_grd, tpr_grd, label='GBT (area = {0:0.3f})'
                       ''.format(auc(fpr_grd, tpr_grd)))
    plt.plot(fpr_grd_lm, tpr_grd_lm, label='GBT + LR (area = {0:0.3f})'
                       ''.format(auc(fpr_grd_lm, tpr_grd_lm)))
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.title('ROC curve')
    plt.legend(loc='best')
    plt.show()

    plt.figure(2)
    plt.figure(figsize=(10,8))
    plt.xlim(0, 0.4)
    plt.ylim(0.6, 1)
    plt.plot([0, 1], [0, 1], 'k--')
    plt.plot(fpr_rt_lm, tpr_rt_lm, label='RT + LR (area = {0:0.3f})'
                       ''.format(auc(fpr_rt_lm, tpr_rt_lm)))
    plt.plot(fpr_rf, tpr_rf, label='RF (area = {0:0.3f})'
                       ''.format(auc(fpr_rf, tpr_rf)))
    plt.plot(fpr_rf_lm, tpr_rf_lm, label='RF + LR (area = {0:0.3f})'
                       ''.format(auc(fpr_rf_lm, tpr_rf_lm)))
    plt.plot(fpr_grd, tpr_grd, label='GBT (area = {0:0.3f})'
                       ''.format(auc(fpr_grd, tpr_grd)))
    plt.plot(fpr_grd_lm, tpr_grd_lm, label='GBT + LR (area = {0:0.3f})'
                       ''.format(auc(fpr_grd_lm, tpr_grd_lm)))
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.title('ROC curve (zoomed in at top left)')
    plt.legend(loc='best')
    plt.show()
    
In [36]:
compare_trees(X_train, X_test, y_train, y_test)
<Figure size 864x1152 with 0 Axes>
<Figure size 864x1152 with 0 Axes>

Parameter tuning

In [158]:
def parameter_tuning_plot(clf_scores, params, measure, title):
    """
    clf_scores - list of scores from the clf classification
    params - list of parameters that were used in the tuning 
    measure - measure to used for the xlabel text
    title - test for the title of the plot
    
    
    parameter_tuning_plot: plot out the CV tuning scores in a line plot
    yields: a line plot with scattered points
    """

    index = range(1,len(params)+1)
    plt.subplots(figsize=(10,8))
    score_means = np.mean(clf_scores, axis=1)
    score_std = np.std(clf_scores, axis=1)
    score_medians = np.median(clf_scores, axis=1)
    plt.scatter(index,score_means, c='g', zorder=3, s=100, marker='o', label= 'Mean of ROC AUC scores')
    plt.errorbar(index, score_means, yerr = 2*score_std,color='#fcba03', alpha =0.7, capsize=10, elinewidth=4, linestyle="None", zorder = 1, label= 'SE of ROC AUC scores')
    plt.title(title)
    plt.legend(frameon=False, loc='lower right')
    plt.ylabel('ROC AUC scores')
    plt.xlabel(measure)
    plt.xticks(index, params, rotation=90)
    legend = plt.legend(loc='lower center',frameon=True,framealpha=0.6, numpoints=1, scatterpoints=1)
    rect = legend.get_frame()
    rect.set_facecolor('#D3D3D3')
    rect.set_linewidth(0.6)
    plt.gca().grid(which='major', axis='y')
    plt.show

Decision tree classifier - which max depth?

In [121]:
Params = [ 1.0, 2.0, 5.0, 7.0, 10.0, 15.0, 100.]
N_param = len(Params)
cv = 5

clf_scores = np.zeros((N_param,cv))
cols = 0
for param in Params:
    clf = DecisionTreeClassifier(max_depth=param, random_state=RANDOM_STATE)
    clf_scores[cols,:] = cross_val_score(clf, X, y, cv=cv, scoring = 'roc_auc' , n_jobs = 8)
    cols += 1
parameter_tuning_plot(clf_scores, Params, 'Max depth', 'ROC AUC scores by penalty parameter')

Bagging classifier - how many trees?

In [159]:
Params = [20, 40, 50, 100, 200, 500]
N_param = len(Params)
cv = 5

clf_scores = np.zeros((N_param,cv))
cols = 0
for param in Params:
    
    clf = BaggingClassifier(n_estimators=param, max_samples=1.0, max_features=0.4, random_state=RANDOM_STATE)
    clf_scores[cols,:] = cross_val_score(clf, X, y, cv=cv, scoring = 'roc_auc' )
    cols += 1
parameter_tuning_plot(clf_scores, Params, '40 trees 40% max features', 'ROC AUC scores by penalty parameter')

Gradient Boosting Trees classifier - how many trees?

In [128]:
Params = [10, 20, 50, 100, 200]
N_param = len(Params)
cv = 5

clf_scores = np.zeros((N_param,cv))
cols = 0
for param in Params:
    clf = GradientBoostingClassifier(n_estimators=param, learning_rate=0.7, max_depth=10, random_state=RANDOM_STATE)
    clf_scores[cols,:] = cross_val_score(clf, X, y, cv=cv, scoring = 'roc_auc' )
    cols += 1
parameter_tuning_plot(clf_scores, Params, 'Number of trees, learning rate 0.7, max-depth=10', 'ROC AUC scores by penalty parameter')

Random Forests - how many trees?

In [197]:
Params = [50, 100, 300, 500, 1000]
N_param = len(Params)
cv = 5

clf_scores = np.zeros((N_param,cv))
cols = 0
for param in Params:
    clf = RandomForestClassifier(n_estimators=param, max_features='auto', random_state=RANDOM_STATE,n_jobs=8)
    clf_scores[cols,:] = cross_val_score(clf, X, y, cv=cv, scoring = 'roc_auc' , n_jobs = 8)
    cols += 1
parameter_tuning_plot(clf_scores, Params, 'Number of trees in the Random Forest', 'ROC AUC scores by choice of number of trees' )

kNN - how many neighbors?

In [215]:
Params = [20, 30 , 40, 50, 60]
N_param = len(Params)
cv = 5

clf_scores = np.zeros((N_param,cv))
cols = 0
for param in Params:
    clf = KNeighborsClassifier(n_neighbors=20)
    clf_scores[cols,:] = cross_val_score(clf, X, y, cv=cv, scoring = 'roc_auc' , n_jobs = 8)
    cols += 1
    
parameter_tuning_plot(clf_scores, Params, 'Regularisation Parameter', 'ROC AUC score by degree of regularisation')

Logistic Regression - which regularization parameter C?

In [126]:
Params = [1e-2,1e-1, 1.0, 1e1, 1e2, 1e3]
N_param = len(Params)
cv = 5

clf_scores = np.zeros((N_param,cv))
cols = 0
for param in Params:
    clf = LogisticRegression(penalty ='l1', C=param, solver='liblinear', random_state=RANDOM_STATE)
    clf_scores[cols,:] = cross_val_score(clf, X, y, cv=cv, scoring = 'roc_auc' , n_jobs = 8)
    cols += 1
    
parameter_tuning_plot(clf_scores, Params, 'Regularisation Parameter', 'ROC AUC score by degree of regularisation')

Predicting the class from the test data

In [330]:
# Benchmark the classifiers, one at a time
def benchmark(clf, name, X_train, X_dev, y_train, y_dev, predY_dev):
    """
    clf - the classifier
    name - its name
    
    benchmark: to create the benchmark metrics for the classification
    returns: the inputs to the results list
    """
    eps = 1e-6
    print('_' * 80)
    print("Training: ")
    print(clf)
    t0 = time()
    clf.fit(X_train, y_train)
    train_time = time() - t0
    print("train time: %0.3fs" % train_time)

    t0 = time()
 
    pred = clf.predict(X_dev)
    test_time = time() - t0
    print("test time:  %0.3fs" % test_time)

    acc_score = metrics.accuracy_score(y_dev, pred)
    print("accuracy:   %0.3f" % acc_score)
    
    f1_macro_score = metrics.f1_score(y_dev, pred, average='macro')
    print("*** F1 macro avg score:   %0.3f" % f1_macro_score)
    
    y_score = clf.predict_proba(X_dev)[:,1]

    auc_score = metrics.roc_auc_score(y_dev, y_score)
    print("*** AUC for ROC = %0.3f\n" % auc_score)
    
    if hasattr(clf, 'coef_'):
        print("dimensionality: %d" % clf.coef_.shape[1])
        print("density: %f" % density(clf.coef_))
    
    print("classification report:")
    print(metrics.classification_report(y_dev, pred,
                                        target_names=categories))
    conf_mat = metrics.confusion_matrix(y_dev, pred)
    print("confusion matrix:")
    print(conf_mat)
    tn, fp, fn, tp = conf_mat.ravel()
    sensitivity = tp / (tp+fn) *100.
    print("\nsensitivity / recall (TPR):")
    print(np.round(sensitivity,3))
    specificity = tn / (tn+fp) *100.
    print("specificity (TNR):")
    print(np.round(specificity,3))
    CCR = ( sensitivity + specificity) / 2 
    print("Correct classification rate (CCR); balanced accuracy:")
    print(np.round(CCR,3))
    R2 = metrics.r2_score(predY_dev,y_score)
    print("R-squared value %F:")
    print(np.round(R2,3))
    mae = metrics.mean_absolute_error(predY_dev,y_score)
    print("MAE value %F:")
    print(np.round(mae,3))
    R2 = metrics.r2_score(log_ratio(predY_dev, eps), log_ratio(y_score, eps))
    print("R-squared value logK(%F):")
    print(np.round(R2,3))
    mae = metrics.mean_absolute_error(log_ratio(predY_dev, eps), log_ratio(y_score, eps))
    print("MAE value logK(%F):")
    print(np.round(mae,3))
    
    print()
    clf_descr = str(clf).split('(')[0]
    return name, acc_score, f1_macro_score, auc_score, train_time, test_time
In [331]:
results_test = []
In [332]:
dtc = DecisionTreeClassifier(max_depth=5, random_state=RANDOM_STATE)
results_test.append(benchmark(dtc, 'Decision Tree classifier with 2 trees', X_train, X_test, y_train, y_test, predY_test))
________________________________________________________________________________
Training: 
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                       max_depth=5, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=458, splitter='best')
train time: 0.068s
test time:  0.000s
accuracy:   0.630
*** F1 macro avg score:   0.628
*** AUC for ROC = 0.642

classification report:
              precision    recall  f1-score   support

     Postive       0.61      0.60      0.60        47
    Negative       0.65      0.66      0.65        53

    accuracy                           0.63       100
   macro avg       0.63      0.63      0.63       100
weighted avg       0.63      0.63      0.63       100

confusion matrix:
[[28 19]
 [18 35]]

sensitivity / recall (TPR):
66.038
specificity (TNR):
59.574
Correct classification rate (CCR); balanced accuracy:
62.806
R-squared value %F:
-0.395
MAE value %F:
0.307
R-squared value logK(%F):
-2.829
MAE value logK(%F):
1.738

In [333]:
bc = BaggingClassifier(n_estimators=500, max_samples=1.0, max_features=0.4, random_state=RANDOM_STATE)
results_test.append(benchmark(bc,  "Bagging | 500 trees 40% max features", X_train, X_test, y_train, y_test, predY_test))
________________________________________________________________________________
Training: 
BaggingClassifier(base_estimator=None, bootstrap=True, bootstrap_features=False,
                  max_features=0.4, max_samples=1.0, n_estimators=500,
                  n_jobs=None, oob_score=False, random_state=458, verbose=0,
                  warm_start=False)
train time: 14.318s
test time:  0.040s
accuracy:   0.770
*** F1 macro avg score:   0.770
*** AUC for ROC = 0.817

classification report:
              precision    recall  f1-score   support

     Postive       0.73      0.81      0.77        47
    Negative       0.81      0.74      0.77        53

    accuracy                           0.77       100
   macro avg       0.77      0.77      0.77       100
weighted avg       0.77      0.77      0.77       100

confusion matrix:
[[38  9]
 [14 39]]

sensitivity / recall (TPR):
73.585
specificity (TNR):
80.851
Correct classification rate (CCR); balanced accuracy:
77.218
R-squared value %F:
0.341
MAE value %F:
0.225
R-squared value logK(%F):
0.108
MAE value logK(%F):
0.829

In [334]:
rf = RandomForestClassifier(n_estimators=1000, max_features='auto', random_state=RANDOM_STATE)
results_test.append(benchmark(rf,  "Random Forest | 1000 trees", X_train, X_test, y_train, y_test, predY_test))
________________________________________________________________________________
Training: 
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=1000,
                       n_jobs=None, oob_score=False, random_state=458,
                       verbose=0, warm_start=False)
train time: 4.898s
test time:  0.066s
accuracy:   0.750
*** F1 macro avg score:   0.750
*** AUC for ROC = 0.815

classification report:
              precision    recall  f1-score   support

     Postive       0.71      0.79      0.75        47
    Negative       0.79      0.72      0.75        53

    accuracy                           0.75       100
   macro avg       0.75      0.75      0.75       100
weighted avg       0.75      0.75      0.75       100

confusion matrix:
[[37 10]
 [15 38]]

sensitivity / recall (TPR):
71.698
specificity (TNR):
78.723
Correct classification rate (CCR); balanced accuracy:
75.211
R-squared value %F:
0.336
MAE value %F:
0.229
R-squared value logK(%F):
0.094
MAE value logK(%F):
0.839

In [335]:
comparison_plots(results_test)
In [191]:
def class_scores(y_dev, pred):
    """
    To print out metrics for classifier results
    y: True labels or binary label indicators
    pred: Target scores. In the binary and multilabel cases, these can be either 
    probability estimates or non-thresholded decision values (as returned by 
    decision_function on some classifiers). In the multiclass case, these must 
    be probability estimates which sum to 1.
    """
    eps = 1e-6
    print('_' * 80)
    y_pred = (pred >= 0.5).astype(bool)
    acc_score = metrics.accuracy_score(y_dev, y_pred)
    print("accuracy:   %0.3f" % acc_score)
    
    f1_macro_score = metrics.f1_score(y_dev, y_pred, average='macro')
    print("*** F1 macro avg score:   %0.3f" % f1_macro_score)

    auc_score = metrics.roc_auc_score(y_dev, pred)
    print("*** AUC for ROC = %0.3f\n" % auc_score)

    print("classification report:")
    print(metrics.classification_report(y_dev, y_pred,
                                        target_names=categories))

    conf_mat = metrics.confusion_matrix(y_dev, y_pred)
    print("confusion matrix:")
    print(conf_mat)
    tn, fp, fn, tp = conf_mat.ravel()
    sensitivity = tp / (tp+fn) *100.
    print("\nsensitivity / recall (TPR):")
    print(np.round(sensitivity,3))
    specificity = tn / (tn+fp) *100.
    print("specificity (TNR):")
    print(np.round(specificity,3))
    CCR = ( sensitivity + specificity) / 2 
    print("Correct classification rate (CCR):")
    print(np.round(CCR,3))
    
    R2 = metrics.r2_score(y_dev,pred)
    print("R-squared value %F:")
    print(np.round(R2,3))
    mae = metrics.mean_absolute_error(y_dev,pred)
    print("MAE value %F:")
    print(np.round(mae,3))
    R2 = metrics.r2_score(log_ratio(y_dev, eps), log_ratio(pred, eps))
    print("R-squared value logK(%F):")
    print(np.round(R2,3))
    mae = metrics.mean_absolute_error(log_ratio(y_dev, eps), log_ratio(pred, eps))
    print("MAE value logK(%F):")
    print(np.round(mae,3))
    

Compare with unbalanced classification methods - performance on the test set

In [192]:
from sklearn.metrics import balanced_accuracy_score
from imblearn.ensemble import RUSBoostClassifier
rusboost = RUSBoostClassifier(n_estimators=200, algorithm='SAMME.R',
                              random_state=RANDOM_STATE)
rusboost.fit(X_train, y_train)  

y_pred = rusboost.predict(X_test)
class_scores(y_test, y_pred) 
________________________________________________________________________________
accuracy:   0.700
*** F1 macro avg score:   0.700
*** AUC for ROC = 0.700

classification report:
              precision    recall  f1-score   support

     Postive       0.67      0.70      0.69        47
    Negative       0.73      0.70      0.71        53

    accuracy                           0.70       100
   macro avg       0.70      0.70      0.70       100
weighted avg       0.70      0.70      0.70       100

confusion matrix:
[[33 14]
 [16 37]]

sensitivity / recall (TPR):
69.811
specificity (TNR):
70.213
Correct classification rate (CCR):
70.012
R-squared value %F:
-0.204
MAE value %F:
0.3
R-squared value logK(%F):
-0.204
MAE value logK(%F):
4.353
In [193]:
from imblearn.ensemble import EasyEnsembleClassifier
easy = EasyEnsembleClassifier(n_estimators=200, random_state=RANDOM_STATE)
easy.fit(X_train, y_train)  

y_pred = easy.predict(X_test)
class_scores(y_test, y_pred) 
________________________________________________________________________________
accuracy:   0.670
*** F1 macro avg score:   0.670
*** AUC for ROC = 0.674

classification report:
              precision    recall  f1-score   support

     Postive       0.62      0.74      0.68        47
    Negative       0.73      0.60      0.66        53

    accuracy                           0.67       100
   macro avg       0.68      0.67      0.67       100
weighted avg       0.68      0.67      0.67       100

confusion matrix:
[[35 12]
 [21 32]]

sensitivity / recall (TPR):
60.377
specificity (TNR):
74.468
Correct classification rate (CCR):
67.423
R-squared value %F:
-0.325
MAE value %F:
0.33
R-squared value logK(%F):
-0.325
MAE value logK(%F):
4.788
In [218]:
from imblearn.metrics import sensitivity_score, specificity_score
CCR = lambda y, y_pred: (sensitivity_score(y, y_pred) + specificity_score(y, y_pred)) / 2.
In [224]:
def plot_gridsearchCV(results, scoring, param='param_min_samples_split'):
    plt.figure(figsize=(13, 13))
    plt.title("GridSearchCV evaluating using multiple scorers simultaneously",
              fontsize=16)

    plt.xlabel(param)
    plt.ylabel("Score")

    ax = plt.gca()
    ax.set_xlim(0,1010)
    ax.set_ylim(0.50, 1)

    # Get the regular numpy array from the MaskedArray
    X_axis = np.array(results[param].data, dtype=float)

    for scorer, color in zip(sorted(scoring), ['g', 'k', 'r', 'b']):
        for sample, style in (('train', '--'), ('test', '-')):
            sample_score_mean = results['mean_%s_%s' % (sample, scorer)]
            sample_score_std = results['std_%s_%s' % (sample, scorer)]
            ax.fill_between(X_axis, sample_score_mean - sample_score_std,
                            sample_score_mean + sample_score_std,
                            alpha=0.1 if sample == 'test' else 0, color=color)
            ax.plot(X_axis, sample_score_mean, style, color=color,
                    alpha=1 if sample == 'test' else 0.7,
                    label="%s (%s)" % (scorer, sample))

        best_index = np.nonzero(results['rank_test_%s' % scorer] == 1)[0][0]
        best_score = results['mean_test_%s' % scorer][best_index]

        # Plot a dotted vertical line at the best score for that scorer marked by x
        ax.plot([X_axis[best_index], ] * 2, [0, best_score],
                linestyle='-.', color=color, marker='x', markeredgewidth=3, ms=8)

        # Annotate the best score for that scorer
        ax.annotate("%0.2f" % best_score,
                    (X_axis[best_index], best_score + 0.005))

    plt.legend(loc="best")
    plt.grid(False)
    plt.show()
In [225]:
from sklearn.model_selection import GridSearchCV
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import BaggingClassifier

from imblearn.metrics import sensitivity_score, specificity_score
from sklearn.metrics import r2_score, mean_absolute_error
CCR = lambda y, y_pred: (sensitivity_score(y, y_pred) + specificity_score(y, y_pred)) / 2.

# The scorers can be either be one of the predefined metric strings or a scorer
# callable, like the one returned by make_scorer
scoring = {'AUC': 'roc_auc', 'Sensitivity': make_scorer(sensitivity_score), 'Specificity': make_scorer(specificity_score), 'CCR': make_scorer(CCR)}
In [212]:
calibrated_forest = CalibratedClassifierCV(
   base_estimator=RandomForestClassifier(max_depth=10, random_state=RANDOM_STATE))
param_grid = {
   'base_estimator__n_estimators': [ 10, 20, 50, 100, 200, 500, 1000]}
search = GridSearchCV(calibrated_forest, param_grid, cv=5, scoring=scoring, refit='AUC', return_train_score=True)
search.fit(X, y)
results = search.cv_results_
In [219]:
plot_gridsearchCV(results, scoring, 'param_base_estimator__n_estimators')
In [232]:
calibrated_forest = CalibratedClassifierCV(
   base_estimator=BaggingClassifier(max_samples=1.0, max_features=0.4, random_state=RANDOM_STATE))
param_grid = {
   'base_estimator__n_estimators': [ 10, 20, 50, 100, 200, 500, 1000]}
search = GridSearchCV(calibrated_forest, param_grid, cv=5, scoring=scoring, refit='AUC', return_train_score=True)
search.fit(X, y)
results = search.cv_results_
In [233]:
plot_gridsearchCV(results, scoring, 'param_base_estimator__n_estimators')
In [303]:
descriptors = df.apply(lambda x: QED.properties(x['ROMol']), axis=1).apply(pd.Series)
descriptors.columns = ['MW', 'ALOGP', 'HBA', 'HBD', 'PSA', 'ROTB', 'AROM', 'ALERTS']
In [304]:
descriptors.head()
Out[304]:
MW ALOGP HBA HBD PSA ROTB AROM ALERTS
0 258.234 -2.8243 7.0 5.0 156.85 3.0 1.0 0.0
1 181.213 -0.5996 4.0 2.0 83.47 4.0 0.0 2.0
2 645.608 -8.5645 19.0 14.0 321.17 9.0 0.0 1.0
3 336.432 2.3655 5.0 3.0 87.66 10.0 1.0 1.0
4 203.238 -1.2357 4.0 0.0 66.43 5.0 0.0 2.0
In [305]:
# Compute the correlation matrix
corr = descriptors.corr()

# Generate a mask for the upper triangle
mask = np.triu(np.ones_like(corr, dtype=np.bool))

# Set up the matplotlib figure
f, ax = plt.subplots(figsize=(16, 16))

# Generate a custom diverging colormap
cmap = sns.diverging_palette(220, 10, as_cmap=True)

# Draw the heatmap with the mask and correct aspect ratio
sns.heatmap(corr, mask=mask, cmap=cmap, vmax=.3, center=0,
            square=True, linewidths=.5, cbar_kws={"shrink": .5})
Out[305]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f003ed41be0>
In [306]:
cols_to_drop = ['ID', 'Name', 'Smiles', 'ROMol', 'QED', 'mol-sentence']
predY = df['PercentF'].astype('float32')
classY = df['Class']
X_df = df.drop(cols_to_drop, axis=1)
X_df = pd.concat([descriptors, X_df], axis=1)
In [307]:
X_df.head()
Out[307]:
MW ALOGP HBA HBD PSA ROTB AROM ALERTS PercentF Class mol2vec-000 mol2vec-001 mol2vec-002 mol2vec-003 mol2vec-004 mol2vec-005 mol2vec-006 mol2vec-007 mol2vec-008 mol2vec-009 mol2vec-010 mol2vec-011 mol2vec-012 mol2vec-013 mol2vec-014 mol2vec-015 mol2vec-016 mol2vec-017 mol2vec-018 mol2vec-019 mol2vec-020 mol2vec-021 mol2vec-022 mol2vec-023 mol2vec-024 mol2vec-025 mol2vec-026 mol2vec-027 mol2vec-028 mol2vec-029 mol2vec-030 mol2vec-031 mol2vec-032 mol2vec-033 mol2vec-034 mol2vec-035 mol2vec-036 mol2vec-037 mol2vec-038 mol2vec-039 ... mol2vec-250 mol2vec-251 mol2vec-252 mol2vec-253 mol2vec-254 mol2vec-255 mol2vec-256 mol2vec-257 mol2vec-258 mol2vec-259 mol2vec-260 mol2vec-261 mol2vec-262 mol2vec-263 mol2vec-264 mol2vec-265 mol2vec-266 mol2vec-267 mol2vec-268 mol2vec-269 mol2vec-270 mol2vec-271 mol2vec-272 mol2vec-273 mol2vec-274 mol2vec-275 mol2vec-276 mol2vec-277 mol2vec-278 mol2vec-279 mol2vec-280 mol2vec-281 mol2vec-282 mol2vec-283 mol2vec-284 mol2vec-285 mol2vec-286 mol2vec-287 mol2vec-288 mol2vec-289 mol2vec-290 mol2vec-291 mol2vec-292 mol2vec-293 mol2vec-294 mol2vec-295 mol2vec-296 mol2vec-297 mol2vec-298 mol2vec-299
0 258.234 -2.8243 7.0 5.0 156.85 3.0 1.0 0.0 10.0 0 5.569553 -3.508770 -6.244288 5.889616 -1.266892 -3.958316 -10.726061 1.087396 0.765944 3.948586 -6.231452 -4.724694 1.894167 5.538465 -3.771910 -0.929421 0.159046 -4.833354 -6.800712 10.953048 3.669181 6.670616 11.887835 9.586557 -9.759455 -2.053182 -4.095988 0.343914 -1.182404 -3.112049 10.556608 -4.206781 -5.059879 -2.508824 2.807159 -0.843330 -1.256958 -2.913525 12.829318 6.306439 ... -11.783781 9.040501 9.574967 -3.682410 -15.312244 7.763779 3.385602 -0.282191 3.674171 -8.455920 0.822764 1.986318 2.576785 8.436736 3.072675 -10.695773 -5.662928 -2.369882 -3.476105 -4.092453 3.857708 3.705558 2.687113 2.462228 -2.661985 -10.343685 -8.095614 4.832535 1.800945 5.311703 -4.275393 -6.753325 -14.168078 -10.244950 -1.247408 4.727633 -6.435288 -1.939992 -1.084997 3.010950 -1.587761 9.543814 3.639564 1.137346 -6.341029 -3.649438 -6.683237 -8.796172 -13.052456 3.266420
1 181.213 -0.5996 4.0 2.0 83.47 4.0 0.0 2.0 11.0 0 0.406591 -0.862599 -0.939995 1.164309 1.466046 -0.642170 -7.448545 1.736181 4.776033 2.029239 -2.281273 -1.312205 -1.264534 4.512798 -2.351368 1.330068 2.290367 -2.327096 -3.163145 5.543420 -0.382270 4.606211 9.746282 4.537701 -6.197850 -2.856963 -3.084367 -2.159275 1.536228 -1.149428 2.132550 -0.733628 -3.376189 -1.771732 1.128453 -1.951368 -2.107461 1.293715 6.300990 4.409171 ... -7.924894 2.748025 2.905460 -2.996611 -5.352388 2.751185 3.358783 -4.225054 5.338570 -4.449970 -1.994026 2.601846 3.836283 6.505186 1.687100 -4.946456 -2.174414 -0.667827 1.473974 -5.721463 1.157199 2.592448 0.135769 3.644752 -3.019293 -5.439225 -0.903813 -0.280284 -0.069495 0.208698 -4.216588 -3.261138 -9.058621 -3.249794 0.463474 1.556273 -1.200347 -0.332860 -3.310076 -0.173647 0.348131 3.798353 3.605726 0.840295 -4.005333 -0.904105 -3.531618 -5.452664 -6.069296 -1.476137
2 645.608 -8.5645 19.0 14.0 321.17 9.0 0.0 1.0 2.0 0 7.495996 -10.775181 -18.471792 9.328063 -2.331269 -18.298042 -28.327505 4.498918 -8.643591 3.430152 -14.472315 -4.259753 21.511595 13.861922 -7.862411 -0.388533 -7.470189 -7.943233 -17.950714 18.455746 15.148751 24.349569 28.305340 20.877964 -20.434948 -3.514576 -6.044527 3.487448 -7.605060 -12.837484 24.317467 -6.451899 -12.126458 -11.029142 0.083344 -7.948068 -6.057978 -4.452465 34.231270 18.262081 ... -29.976904 29.419573 27.546228 -4.300668 -39.698620 21.187691 8.103436 2.837830 12.491193 -18.682749 4.830894 6.739665 0.985432 16.328875 7.279101 -25.315340 -5.638966 -6.691320 -4.365877 -8.003871 17.895763 14.207617 23.186853 -0.489443 -9.554961 -28.975965 -20.494970 15.271966 0.307983 16.753927 -9.837640 -13.661139 -39.439987 -31.803679 -1.562862 23.051231 -18.637310 -5.299712 -3.953747 6.623796 -8.590086 28.358700 -1.265912 3.456692 -18.159924 -20.828270 -25.539047 -22.109501 -34.757881 8.464603
3 336.432 2.3655 5.0 3.0 87.66 10.0 1.0 1.0 37.0 0 1.091887 1.854251 -4.900371 7.193990 1.563223 -1.963947 -18.288513 2.452731 10.435803 1.950667 -8.363561 -1.838400 -0.970261 6.558493 -3.511793 -0.082776 3.361502 -5.248492 -6.289870 13.695190 1.414803 11.211549 18.760096 13.486296 -13.724027 -6.549903 -3.427528 -6.586008 1.816793 -2.434518 11.115354 -3.151559 -3.251606 -3.340417 2.648002 -2.288478 -2.642019 -2.381675 15.069605 6.336995 ... -15.862226 9.795337 4.405685 -5.558150 -13.292850 4.238952 5.485982 -4.252264 9.935269 -12.415733 -1.644810 5.150281 4.815377 11.964157 4.764205 -11.726566 -8.082496 2.454412 4.449689 -11.469698 3.415305 10.346655 -1.208327 6.706863 -2.987926 -9.882215 -3.658513 0.836102 1.988530 2.590983 -8.181289 -4.651784 -19.863239 -9.365655 0.687346 7.646068 -3.439802 0.973907 -5.419662 4.962143 1.484586 8.792830 8.744014 3.414705 -5.819517 -1.169135 -5.415274 -13.250051 -13.500473 -1.355382
4 203.238 -1.2357 4.0 0.0 66.43 5.0 0.0 2.0 10.0 0 3.205750 -2.312642 -4.117469 1.197931 3.834179 -1.328032 -12.317728 4.010719 3.174892 1.512536 -3.611534 -1.978931 0.185490 7.318982 1.638154 4.105317 5.458116 -4.732348 -4.114553 3.884579 4.520523 2.980761 14.172358 8.302854 -4.640127 -2.562643 -1.439694 -2.975726 7.404806 -6.717851 0.764437 0.772412 -1.930308 0.014512 -1.280382 1.107110 -1.804889 1.280459 7.326070 2.090013 ... -4.537293 6.940051 1.277370 -2.195449 -5.351538 3.301649 6.007098 -1.513198 4.415155 -4.669563 -1.298284 2.810960 4.510495 6.650487 0.408486 -9.343277 -2.912362 -0.129418 2.342479 -1.667931 4.641690 4.510093 3.199713 2.232408 -2.000668 -6.569041 -0.633564 -0.627894 1.491187 3.399392 -7.462275 -2.675481 -9.861426 -2.502748 3.187210 6.317572 -2.026812 2.891784 -0.426157 1.656910 -0.808890 2.536302 5.873821 1.676117 -3.370314 -1.502855 1.750698 -9.177335 -5.981993 0.820515

5 rows × 310 columns

In [308]:
X_ = np.concatenate((X, descriptors.values), axis=1)
X_.shape
Out[308]:
(995, 308)
In [226]:
calibrated_forest = CalibratedClassifierCV(
   base_estimator=RandomForestClassifier(max_depth=10, random_state=RANDOM_STATE))
param_grid = {
   'base_estimator__n_estimators': [ 10, 20, 50, 100, 200, 500, 1000]}
search = GridSearchCV(calibrated_forest, param_grid, cv=5, scoring=scoring, refit='AUC', return_train_score=True)
search.fit(X_, y)
results = search.cv_results_
In [227]:
plot_gridsearchCV(results, scoring, 'param_base_estimator__n_estimators')
In [228]:
calibrated_forest = CalibratedClassifierCV(
   base_estimator=BaggingClassifier(max_samples=1.0, max_features=0.4, random_state=RANDOM_STATE))
param_grid = {
   'base_estimator__n_estimators': [ 10, 20, 50, 100, 200, 500, 1000]}
search = GridSearchCV(calibrated_forest, param_grid, cv=5, scoring=scoring, refit='AUC', return_train_score=True)
search.fit(X_, y)
results = search.cv_results_
In [229]:
plot_gridsearchCV(results, scoring, 'param_base_estimator__n_estimators')
In [230]:
calibrated_forest = CalibratedClassifierCV(
   base_estimator=RandomForestClassifier(max_depth=10, random_state=RANDOM_STATE))
param_grid = {
   'base_estimator__n_estimators': [ 10, 20, 50, 100, 200, 500, 1000]}
search = GridSearchCV(calibrated_forest, param_grid, cv=5, scoring=scoring, refit='AUC', return_train_score=True)
search.fit(descriptors.values, y)
results = search.cv_results_
In [231]:
plot_gridsearchCV(results, scoring, 'param_base_estimator__n_estimators')
In [75]:
from sklearn.pipeline import Pipeline
from sklearn.feature_selection import SelectKBest
pipe = Pipeline([
   ('select', SelectKBest()),
   ('model', calibrated_forest)])
param_grid = {
   'select__k': [50, 100, 300],
   'model__base_estimator__n_estimators': [50, 100, 200]}
search = GridSearchCV(pipe, param_grid, cv=5, scoring=scoring, refit='AUC', return_train_score=True)
search.fit(X, y)
results = search.cv_results_
In [78]:
results
Out[78]:
{'mean_fit_time': array([0.52800746, 0.64606586, 0.92466302, 1.04188371, 1.28566489,
        1.86234818, 2.07451382, 2.54261332, 3.65762415]),
 'std_fit_time': array([0.0026333 , 0.00330789, 0.00486905, 0.00403843, 0.02098747,
        0.03018251, 0.00847932, 0.01076099, 0.02120539]),
 'mean_score_time': array([0.04030843, 0.040381  , 0.04040942, 0.07538357, 0.0757606 ,
        0.07648807, 0.14509172, 0.14481468, 0.1454565 ]),
 'std_score_time': array([2.66503234e-04, 1.09489682e-04, 4.22558443e-05, 1.67453141e-03,
        1.56222984e-03, 2.39671718e-03, 4.55729650e-04, 6.87388730e-04,
        9.24285647e-04]),
 'param_model__base_estimator__n_estimators': masked_array(data=[50, 50, 50, 100, 100, 100, 200, 200, 200],
              mask=[False, False, False, False, False, False, False, False,
                    False],
        fill_value='?',
             dtype=object),
 'param_select__k': masked_array(data=[50, 100, 300, 50, 100, 300, 50, 100, 300],
              mask=[False, False, False, False, False, False, False, False,
                    False],
        fill_value='?',
             dtype=object),
 'params': [{'model__base_estimator__n_estimators': 50, 'select__k': 50},
  {'model__base_estimator__n_estimators': 50, 'select__k': 100},
  {'model__base_estimator__n_estimators': 50, 'select__k': 300},
  {'model__base_estimator__n_estimators': 100, 'select__k': 50},
  {'model__base_estimator__n_estimators': 100, 'select__k': 100},
  {'model__base_estimator__n_estimators': 100, 'select__k': 300},
  {'model__base_estimator__n_estimators': 200, 'select__k': 50},
  {'model__base_estimator__n_estimators': 200, 'select__k': 100},
  {'model__base_estimator__n_estimators': 200, 'select__k': 300}],
 'split0_test_AUC': array([0.73388563, 0.72832896, 0.72054961, 0.73166296, 0.72994544,
        0.7182259 , 0.73247121, 0.72196403, 0.72378258]),
 'split1_test_AUC': array([0.75672125, 0.7441884 , 0.74691732, 0.75621589, 0.74459268,
        0.75277946, 0.74853447, 0.7413584 , 0.75560946]),
 'split2_test_AUC': array([0.7247827 , 0.74358197, 0.75459875, 0.73327269, 0.74145947,
        0.75399232, 0.73347483, 0.74145947, 0.75530625]),
 'split3_test_AUC': array([0.78552658, 0.7953305 , 0.79896907, 0.79260158, 0.79169193,
        0.79846372, 0.78896301, 0.793208  , 0.79775622]),
 'split4_test_AUC': array([0.70871235, 0.72650091, 0.73468769, 0.7162927 , 0.71588842,
        0.73509197, 0.71457449, 0.72235698, 0.73478876]),
 'mean_test_AUC': array([0.7419257 , 0.74758615, 0.75114449, 0.74600916, 0.74471559,
        0.75171068, 0.7436036 , 0.74406938, 0.75344865]),
 'std_test_AUC': array([0.02676414, 0.02499011, 0.02654756, 0.02655609, 0.02555996,
        0.02679405, 0.02510379, 0.02603398, 0.02529643]),
 'rank_test_AUC': array([9, 4, 3, 5, 6, 2, 8, 7, 1], dtype=int32),
 'split0_train_AUC': array([0.99987366, 0.99968415, 0.99993683, 0.99995578, 0.99986734,
        0.99995578, 0.99993683, 0.99991788, 0.99994946]),
 'split1_train_AUC': array([0.99991157, 1.        , 0.99994947, 0.99993684, 0.99998105,
        0.99998105, 0.99998737, 0.99997474, 1.        ]),
 'split2_train_AUC': array([0.9999621 , 1.        , 0.99989894, 0.99993684, 0.99991789,
        0.99993684, 0.99998737, 0.9999621 , 0.99997474]),
 'split3_train_AUC': array([0.99989262, 1.        , 0.99997474, 0.99993052, 0.99997474,
        0.99993052, 0.99990526, 1.        , 0.99995579]),
 'split4_train_AUC': array([0.99922942, 0.9999621 , 0.9999621 , 0.99957681, 0.99995579,
        0.99998105, 0.99953892, 0.99998105, 1.        ]),
 'mean_train_AUC': array([0.99977388, 0.99992925, 0.99994442, 0.99986736, 0.99993936,
        0.99995705, 0.99987115, 0.99996715, 0.999976  ]),
 'std_train_AUC': array([2.73816255e-04, 1.23425295e-04, 2.60121526e-05, 1.45518940e-04,
        4.22059800e-05, 2.12885677e-05, 1.69037773e-04, 2.75062197e-05,
        2.12900161e-05]),
 'split0_test_Sensitivity': array([0.68316832, 0.71287129, 0.68316832, 0.7029703 , 0.68316832,
        0.66336634, 0.7029703 , 0.7029703 , 0.69306931]),
 'split1_test_Sensitivity': array([0.73529412, 0.73529412, 0.7254902 , 0.74509804, 0.73529412,
        0.7745098 , 0.74509804, 0.7254902 , 0.75490196]),
 'split2_test_Sensitivity': array([0.71568627, 0.7254902 , 0.7254902 , 0.74509804, 0.7254902 ,
        0.7254902 , 0.73529412, 0.7254902 , 0.71568627]),
 'split3_test_Sensitivity': array([0.67647059, 0.70588235, 0.65686275, 0.7254902 , 0.67647059,
        0.67647059, 0.69607843, 0.67647059, 0.67647059]),
 'split4_test_Sensitivity': array([0.6372549 , 0.65686275, 0.69607843, 0.67647059, 0.62745098,
        0.69607843, 0.64705882, 0.64705882, 0.67647059]),
 'mean_test_Sensitivity': array([0.68957484, 0.70728014, 0.69741798, 0.71902543, 0.68957484,
        0.70718307, 0.70529994, 0.69549602, 0.70331974]),
 'std_test_Sensitivity': array([0.03382957, 0.02717174, 0.02617508, 0.02634402, 0.03860649,
        0.03963462, 0.03454282, 0.03021025, 0.02953128]),
 'rank_test_Sensitivity': array([8, 2, 6, 1, 8, 3, 4, 7, 5], dtype=int32),
 'split0_train_Sensitivity': array([0.99754902, 0.99509804, 1.        , 0.99754902, 0.99754902,
        1.        , 0.99754902, 1.        , 0.99754902]),
 'split1_train_Sensitivity': array([1.      , 1.      , 1.      , 0.995086, 1.      , 1.      ,
        1.      , 1.      , 1.      ]),
 'split2_train_Sensitivity': array([0.995086, 1.      , 1.      , 1.      , 1.      , 1.      ,
        1.      , 0.997543, 1.      ]),
 'split3_train_Sensitivity': array([0.995086, 1.      , 1.      , 0.995086, 1.      , 0.997543,
        1.      , 1.      , 1.      ]),
 'split4_train_Sensitivity': array([0.997543, 1.      , 1.      , 1.      , 1.      , 1.      ,
        1.      , 1.      , 1.      ]),
 'mean_train_Sensitivity': array([0.9970528 , 0.99901961, 1.        , 0.9975442 , 0.9995098 ,
        0.9995086 , 0.9995098 , 0.9995086 , 0.9995098 ]),
 'std_train_Sensitivity': array([0.00183898, 0.00196078, 0.        , 0.00219761, 0.00098039,
        0.0009828 , 0.00098039, 0.0009828 , 0.00098039]),
 'split0_test_Specificity': array([0.67346939, 0.67346939, 0.67346939, 0.67346939, 0.71428571,
        0.65306122, 0.65306122, 0.67346939, 0.69387755]),
 'split1_test_Specificity': array([0.68041237, 0.65979381, 0.65979381, 0.64948454, 0.64948454,
        0.60824742, 0.67010309, 0.6185567 , 0.64948454]),
 'split2_test_Specificity': array([0.58762887, 0.65979381, 0.69072165, 0.59793814, 0.67010309,
        0.68041237, 0.59793814, 0.67010309, 0.67010309]),
 'split3_test_Specificity': array([0.75257732, 0.74226804, 0.7628866 , 0.77319588, 0.75257732,
        0.7628866 , 0.7628866 , 0.75257732, 0.75257732]),
 'split4_test_Specificity': array([0.65979381, 0.65979381, 0.69072165, 0.65979381, 0.65979381,
        0.67010309, 0.64948454, 0.65979381, 0.68041237]),
 'mean_test_Specificity': array([0.67077635, 0.67902377, 0.69551862, 0.67077635, 0.6892489 ,
        0.67494214, 0.66669472, 0.67490006, 0.68929097]),
 'std_test_Specificity': array([0.05258445, 0.03206263, 0.03562835, 0.05723002, 0.03860751,
        0.05042723, 0.05380339, 0.04350078, 0.03480734]),
 'rank_test_Specificity': array([7, 4, 1, 7, 3, 5, 9, 6, 2], dtype=int32),
 'split0_train_Specificity': array([0.97938144, 0.99226804, 0.98969072, 0.98969072, 0.99226804,
        0.98969072, 0.99484536, 0.98969072, 0.98969072]),
 'split1_train_Specificity': array([0.99228792, 0.99485861, 0.98200514, 0.98971722, 0.99228792,
        0.99485861, 0.99228792, 0.99228792, 0.99485861]),
 'split2_train_Specificity': array([0.99228792, 0.99228792, 0.99485861, 0.99228792, 0.98971722,
        0.99485861, 0.98971722, 0.99485861, 0.99742931]),
 'split3_train_Specificity': array([0.99742931, 0.98971722, 0.99742931, 0.99228792, 0.99485861,
        0.99485861, 0.99228792, 1.        , 0.99742931]),
 'split4_train_Specificity': array([0.97686375, 0.98714653, 0.99228792, 0.98200514, 0.97943445,
        0.98714653, 0.98457584, 0.98457584, 0.99228792]),
 'mean_train_Specificity': array([0.98765007, 0.99125566, 0.99125434, 0.98919778, 0.98971325,
        0.99228262, 0.99074285, 0.99228262, 0.99433917]),
 'std_train_Specificity': array([0.00804198, 0.00262006, 0.00529494, 0.00377743, 0.00539044,
        0.0032559 , 0.00348394, 0.00514405, 0.00300611]),
 'split0_test_CCR': array([0.67831885, 0.69317034, 0.67831885, 0.68821984, 0.69872702,
        0.65821378, 0.67801576, 0.68821984, 0.69347343]),
 'split1_test_CCR': array([0.70785324, 0.69754397, 0.69264201, 0.69729129, 0.69238933,
        0.69137861, 0.70760057, 0.67202345, 0.70219325]),
 'split2_test_CCR': array([0.65165757, 0.69264201, 0.70810592, 0.67151809, 0.69779664,
        0.70295128, 0.66661613, 0.69779664, 0.69289468]),
 'split3_test_CCR': array([0.71452395, 0.7240752 , 0.70987467, 0.74934304, 0.71452395,
        0.71967859, 0.72948251, 0.71452395, 0.71452395]),
 'split4_test_CCR': array([0.64852436, 0.65832828, 0.69340004, 0.6681322 , 0.6436224 ,
        0.68309076, 0.64827168, 0.65342632, 0.67844148]),
 'mean_test_CCR': array([0.6801756 , 0.69315196, 0.6964683 , 0.69490089, 0.68941187,
        0.69106261, 0.68599733, 0.68519804, 0.69630536]),
 'std_test_CCR': array([0.02743862, 0.02092121, 0.01156393, 0.02924933, 0.02405508,
        0.0205081 , 0.02905244, 0.02102797, 0.01187733]),
 'rank_test_CCR': array([9, 4, 1, 3, 6, 5, 7, 8, 2], dtype=int32),
 'split0_train_CCR': array([0.98846523, 0.99368304, 0.99484536, 0.99361987, 0.99490853,
        0.99484536, 0.99619719, 0.99484536, 0.99361987]),
 'split1_train_CCR': array([0.99614396, 0.99742931, 0.99100257, 0.99240161, 0.99614396,
        0.99742931, 0.99614396, 0.99614396, 0.99742931]),
 'split2_train_CCR': array([0.99368696, 0.99614396, 0.99742931, 0.99614396, 0.99485861,
        0.99742931, 0.99485861, 0.9962008 , 0.99871465]),
 'split3_train_CCR': array([0.99625765, 0.99485861, 0.99871465, 0.99368696, 0.99742931,
        0.9962008 , 0.99614396, 1.        , 0.99871465]),
 'split4_train_CCR': array([0.98720338, 0.99357326, 0.99614396, 0.99100257, 0.98971722,
        0.99357326, 0.99228792, 0.99228792, 0.99614396]),
 'mean_train_CCR': array([0.99235143, 0.99513764, 0.99562717, 0.99337099, 0.99461153,
        0.99589561, 0.99512633, 0.99589561, 0.99692449]),
 'std_train_CCR': array([0.00382179, 0.00147685, 0.00264747, 0.00169707, 0.00262266,
        0.00150292, 0.0015064 , 0.00249477, 0.00190756])}
In [147]:
import xgboost as xgb


def rmse(predictions, targets):
    return np.sqrt(((predictions - targets) ** 2).mean())

param = {'max_depth':2, 'eta':1, 'objective':'binary:logistic' }
num_round = 2
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test)
#bst = xgb.train(param, dtrain, num_round)
# make prediction
#y_predicted = bst.predict(dtest)

res = xgb.cv(param, dtrain, num_boost_round=1000, nfold=5, seed=RANDOM_STATE, stratified=False,
             early_stopping_rounds=25, verbose_eval=10, show_stdv=True)
    
best_nrounds = res.shape[0] - 1
print(np.shape(X_train), np.shape(X_test), np.shape(y_train), np.shape(y_test))
gbdt = xgb.train(param, dtrain, best_nrounds)
y_predicted = gbdt.predict(dtest)

class_scores(y_test, y_predicted) 
[0]	train-error:0.33166+0.01282	test-error:0.36688+0.03266
[10]	train-error:0.14039+0.00897	test-error:0.38329+0.05562
[20]	train-error:0.05120+0.00586	test-error:0.37822+0.03610
(796, 300) (100, 300) (796,) (100,)
________________________________________________________________________________
accuracy:   0.650
*** F1 macro avg score:   0.649
*** AUC for ROC = 0.762

classification report:
              precision    recall  f1-score   support

     Postive       0.62      0.64      0.63        47
    Negative       0.67      0.66      0.67        53

    accuracy                           0.65       100
   macro avg       0.65      0.65      0.65       100
weighted avg       0.65      0.65      0.65       100

confusion matrix:
[[30 17]
 [18 35]]

sensitivity / recall (TPR):
66.038
specificity (TNR):
63.83
Correct classification rate (CCR):
64.934
In [67]:
# Setting refit='AUC', refits an estimator on the whole dataset with the
# parameter setting that has the best cross-validated AUC score.
# That estimator is made available at ``gs.best_estimator_`` along with
# parameters like ``gs.best_score_``, ``gs.best_params_`` and
# ``gs.best_index_``
gs = GridSearchCV(DecisionTreeClassifier(random_state=42),
                  param_grid={'min_samples_split': range(2, 403, 10)},
                  scoring=scoring, refit='AUC', return_train_score=True)
gs.fit(X, y)
results = gs.cv_results_
In [68]:
plot_gridsearchCV(results, scoring)

Prediction of various datasets

In [430]:
def input_test_data(in_file_pos, in_file_neg, out_file):

    df_pos = pd.read_csv(in_file_pos, delimiter=',', usecols=[0, 1, 2], names=['ID', 'Name', 'Smiles'], encoding='latin-1')  # Assume <tab> separated
    df_neg = pd.read_csv(in_file_neg, delimiter=',', usecols=[0, 1, 2], names=['ID', 'Name', 'Smiles'], encoding='latin-1')  # Assume <tab> separated
    df_neg['Class'] = 0
    print('Negative dataframe columns')
    print(df_neg.columns)

    df_pos['Class'] = 1
    print('Postive dataframe columns')
    print(df_pos.columns)

    df = pd.concat([df_neg, df_pos])
    df.reset_index(drop=True, inplace=True)
    
    df_input = df
    
    model_path = './models/model_300dim.pkl'

    X, df = featurize(df, out_file, model_path, 2, uncommon='UNK')

    descriptors = df.apply(lambda x: QED.properties(x['ROMol']), axis=1).apply(pd.Series)
    descriptors.columns = ['MW', 'ALOGP', 'HBA', 'HBD', 'PSA', 'ROTB', 'AROM', 'ALERTS']

    cols_to_drop = ['ID', 'Name', 'Smiles', 'Class', 'ROMol', 'QED', 'mol-sentence']

    classY = df['Class']
    df = df.drop(cols_to_drop, axis=1)
    df = pd.concat([descriptors, df], axis=1)
    X = df.values

    return X, df, classY, df_input
In [447]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.manifold import TSNE
from sklearn.metrics import plot_confusion_matrix

import umap.umap_ as UMAP

def plot_scatter(test_X, test_classY):
    #pca_model = PCA(n_components=30)
    scaler = StandardScaler()
    pca_model = PCA(n_components=2)
    umap_model = UMAP.UMAP()

    pca = pca_model.fit_transform(scaler.fit_transform(test_X))
    umap = umap_model.fit_transform(scaler.fit_transform(test_X))

    df_vec = pd.DataFrame()
    df_vec['PCA-c1'] = pca.T[0]
    df_vec['PCA-c2'] = pca.T[1]
    df_vec['UMAP-c1'] = umap.T[0]
    df_vec['UMAP-c2'] = umap.T[1]
    df_vec['Class'] = ['Positive dataset'  if x == 1 
                   else 'Negative dataset' for x in test_classY.tolist()]

    f, ax = plt.subplots(figsize=(8, 8))
    sns.scatterplot('UMAP-c1','UMAP-c2', hue='Class', data=df_vec, legend='full', palette=palette)
    plt.show()
    return df_vec
In [350]:
def run_classifier (clf, X_train, X_dev, y_train, y_test):
    """
    clf - the classifier
    returns: the predictions and the probabilities
    """
    print('_' * 80)
    print("Training: ")
    print(clf)
    t0 = time()
    clf.fit(X_train, y_train)
    train_time = time() - t0
    print("train time: %0.3fs" % train_time)

    pred = clf.predict(X_dev)
    prob = clf.predict_proba(X_dev)[:,1]
    
    df_vec = pd.DataFrame()
    df_vec['Class'] = ['Positive dataset'  if x == 1 
                   else 'Negative dataset' for x in y_test.tolist()]
    df_vec['Prediction'] = pred
    df_vec['Probability'] = prob
    
    return pred, prob, df_vec
    
In [351]:
def plot_violin(test_classY, pred, prob):
    # Set up the matplotlib figure
    f, ax = plt.subplots(figsize=(8, 8))
    
    df_vec = pd.DataFrame()
    df_vec['Class'] = ['Positive dataset'  if x == 1 
                   else 'Negative dataset' for x in test_classY.tolist()]
    df_vec['Prediction'] = pred
    df_vec['Probability'] = prob

    sns.violinplot(x="Class", y="Probability", data=df_vec, palette="muted")
    plt.show()
In [352]:
def plot_stripplot(test_classY, pred, prob):
    # Set up the matplotlib figure
    f, ax = plt.subplots(figsize=(8, 8))
    
    df_vec = pd.DataFrame()
    df_vec['Class'] = ['Positive dataset'  if x == 1 
                   else 'Negative dataset' for x in test_classY.tolist()]
    df_vec['Prediction'] = pred
    df_vec['Probability'] = prob

    sns.stripplot(x="Class", y="Probability", data=df_vec, palette="muted", jitter=0.10)
    plt.show()
In [353]:
clf = BaggingClassifier(n_estimators=50, max_samples=1.0, max_features=0.4, random_state=RANDOM_STATE)

clf = RandomForestClassifier(n_estimators=500, max_depth=10, max_features='auto', random_state=RANDOM_STATE)

df_scat = plot_scatter(X_test, y_test)

pred, prob, df_vec = run_classifier (clf, X_train, X_test, y_train, y_test)


plt.subplots(figsize=(8, 8))
sns.violinplot(x="Class", y="Probability", data=df_vec, palette="muted")
plt.show()

plt.subplots(figsize=(8, 8))
sns.stripplot(x="Class", y="Probability", data=df_vec, palette="muted", jitter=0.10)
plt.show()
________________________________________________________________________________
Training: 
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=10, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=500,
                       n_jobs=None, oob_score=False, random_state=458,
                       verbose=0, warm_start=False)
train time: 2.258s

Case One - HSDB vs ChemBL 4-phases

In [380]:
in_file_pos = './BA-Chembl-4-phases-smiles-no-repeat.csv'
in_file_neg = './hsbd-smiles-no-repeat.csv'
out_file = 'BA-vectors.csv'

test_X, test_df, test_classY = input_test_data(in_file_pos, in_file_neg, out_file)
test_X = test_X[:,8:]
Negative dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Postive dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Loading molecules.
RDKit ERROR: [15:42:13] SMILES Parse Error: syntax error for input: 'Smiles'
RDKit ERROR: [15:42:13] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [15:42:13] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [15:42:13] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [15:42:13] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [15:42:13] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [15:42:13] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [15:42:14] Explicit valence for atom # 1 Cl, 7, is greater than permitted
RDKit ERROR: [15:42:14] SMILES Parse Error: syntax error for input: 'Smiles'
Keeping only molecules that can be processed by RDKit.
Featurizing molecules.
(7365, 300)
In [381]:
clf = BaggingClassifier(n_estimators=50, max_samples=1.0, max_features=0.4, random_state=RANDOM_STATE)
clf = RandomForestClassifier(n_estimators=500, max_features='auto', random_state=RANDOM_STATE)

df_scat = plot_scatter(test_X, test_classY)

pred, prob, df_vec = run_classifier (clf, X, test_X, y, test_classY)


plt.subplots(figsize=(8, 8))
sns.violinplot(x="Class", y="Probability", data=df_vec, palette="muted")
plt.show()

plt.subplots(figsize=(8, 8))
sns.stripplot(x="Class", y="Probability", data=df_vec, palette="muted", jitter=0.10)
plt.show()
________________________________________________________________________________
Training: 
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=500,
                       n_jobs=None, oob_score=False, random_state=458,
                       verbose=0, warm_start=False)
train time: 3.099s
In [382]:
title = 'Confusion matrix with input classes'
disp = plot_confusion_matrix(clf, test_X, test_classY,
                             display_labels=[0,1],
                             cmap=plt.cm.Blues)
disp.ax_.set_title(title)

print()
print(disp.confusion_matrix)
plt.show()
[[1075 2448]
 [2361 1481]]

Case Prediction - BA curated

In [463]:
def input_single_test_data(in_file, out_file):

    df = pd.read_csv(in_file, delimiter=',', usecols=[0, 1], names=['Name', 'Smiles'], encoding='latin-1')  # Assume <tab> separated
    
    df_input = df
    
    model_path = './models/model_300dim.pkl'

    X, df = featurize(df, out_file, model_path, 2, uncommon='UNK')

    descriptors = df.apply(lambda x: QED.properties(x['ROMol']), axis=1).apply(pd.Series)
    descriptors.columns = ['MW', 'ALOGP', 'HBA', 'HBD', 'PSA', 'ROTB', 'AROM', 'ALERTS']

    cols_to_drop = ['Name', 'Smiles', 'ROMol', 'QED', 'mol-sentence']

    df = df.drop(cols_to_drop, axis=1)
    df = pd.concat([descriptors, df], axis=1)
    X = df.values

    return X, df, df_input
In [464]:
in_file = './All_curated.csv'
out_file = 'All_curated-vectors.csv'

test_X, test_df, df_input = input_single_test_data(in_file, out_file)
test_X = test_X[:,8:]
Loading molecules.
RDKit ERROR: [18:27:32] Explicit valence for atom # 2 O, 3, is greater than permitted
RDKit ERROR: [18:27:32] Explicit valence for atom # 0 N, 4, is greater than permitted
RDKit ERROR: [18:27:32] Explicit valence for atom # 0 N, 4, is greater than permitted
RDKit ERROR: [18:27:33] Explicit valence for atom # 0 N, 4, is greater than permitted
RDKit ERROR: [18:27:33] Explicit valence for atom # 13 Cl, 5, is greater than permitted
RDKit ERROR: [18:27:33] SMILES Parse Error: syntax error for input: 'OS(O)(O)C1=CC=C(C=C1)C-1=C2\C=CC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC=C(C=C1)S(O)(O)O)C1=CC=C(C=C1)S([O-])([O-])[O-])\C1=CC=C(C=C1)S(O)(O)[O-]'
RDKit ERROR: [18:27:33] Explicit valence for atom # 14 N, 5, is greater than permitted
RDKit ERROR: [18:27:33] Explicit valence for atom # 19 O, 3, is greater than permitted
RDKit ERROR: [18:27:33] Explicit valence for atom # 2 O, 3, is greater than permitted
RDKit ERROR: [18:27:33] Explicit valence for atom # 6 N, 4, is greater than permitted
RDKit ERROR: [18:27:33] Explicit valence for atom # 11 N, 4, is greater than permitted
RDKit ERROR: [18:27:33] Explicit valence for atom # 0 O, 3, is greater than permitted
RDKit ERROR: [18:27:33] Explicit valence for atom # 6 Be, 4, is greater than permitted
RDKit ERROR: [18:27:33] Explicit valence for atom # 3 N, 4, is greater than permitted
RDKit ERROR: [18:27:33] Explicit valence for atom # 4 F, 2, is greater than permitted
RDKit ERROR: [18:27:34] Explicit valence for atom # 13 Be, 3, is greater than permitted
RDKit ERROR: [18:27:34] Explicit valence for atom # 2 N, 4, is greater than permitted
RDKit ERROR: [18:27:34] SMILES Parse Error: syntax error for input: 'OC1=CC=CC(=C1)C-1=C2\CCC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC(O)=CC=C1)C1=CC(O)=CC=C1)\C1=CC(O)=CC=C1'
RDKit ERROR: [18:27:35] Explicit valence for atom # 1 Cl, 4, is greater than permitted
RDKit ERROR: [18:27:35] Explicit valence for atom # 0 N, 4, is greater than permitted
RDKit ERROR: [18:27:35] Explicit valence for atom # 0 Cl, 2, is greater than permitted
RDKit ERROR: [18:27:35] Explicit valence for atom # 5 K, 2, is greater than permitted
RDKit ERROR: [18:27:35] Explicit valence for atom # 0 Mg, 4, is greater than permitted
Keeping only molecules that can be processed by RDKit.
Featurizing molecules.
(12798, 300)
In [465]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.manifold import TSNE
from sklearn.metrics import plot_confusion_matrix

import umap.umap_ as UMAP

def plot_single_scatter(test_X):
    #pca_model = PCA(n_components=30)
    scaler = StandardScaler()
    pca_model = PCA(n_components=2)
    umap_model = UMAP.UMAP()

    pca = pca_model.fit_transform(scaler.fit_transform(test_X))
    umap = umap_model.fit_transform(scaler.fit_transform(test_X))

    df_vec = pd.DataFrame()
    df_vec['PCA-c1'] = pca.T[0]
    df_vec['PCA-c2'] = pca.T[1]
    df_vec['UMAP-c1'] = umap.T[0]
    df_vec['UMAP-c2'] = umap.T[1]

    f, ax = plt.subplots(figsize=(8, 8))
    sns.scatterplot('UMAP-c1','UMAP-c2',data=df_vec, legend='full', palette=palette)
    plt.show()
    return df_vec
In [467]:
def run_single_classifier (clf, X_train, X_dev, y_train):
    """
    clf - the classifier
    returns: the predictions and the probabilities
    """
    print('_' * 80)
    print("Training: ")
    print(clf)
    t0 = time()
    clf.fit(X_train, y_train)
    train_time = time() - t0
    print("train time: %0.3fs" % train_time)

    pred = clf.predict(X_dev)
    prob = clf.predict_proba(X_dev)[:,1]
    
    df_vec = pd.DataFrame()
    df_vec['Class'] = ['Positive dataset'  if x == 1 
                   else 'Negative dataset' for x in pred.tolist()]
    df_vec['Prediction'] = pred
    df_vec['Probability'] = prob
    
    return pred, prob, df_vec
    
In [468]:
clf = RandomForestClassifier(n_estimators=500, max_features='auto', random_state=RANDOM_STATE)

df_scat = plot_single_scatter(test_X)

pred, prob, df_vec = run_single_classifier (clf, X, test_X, y)


plt.subplots(figsize=(8, 8))
sns.violinplot(x="Class", y="Probability", data=df_vec, palette="muted")
plt.show()

plt.subplots(figsize=(8, 8))
sns.stripplot(x="Class", y="Probability", data=df_vec, palette="muted", jitter=0.10)
plt.show()
________________________________________________________________________________
Training: 
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=500,
                       n_jobs=None, oob_score=False, random_state=458,
                       verbose=0, warm_start=False)
train time: 3.085s
In [451]:
df_scat.columns
Out[451]:
Index(['PCA-c1', 'PCA-c2', 'UMAP-c1', 'UMAP-c2', 'Class'], dtype='object')
In [452]:
df_vec['ID'] = df_input['ID']
df_vec['Name'] = df_input['Name']
df_vec['Smiles'] = df_input['Smiles']
df_vec['x-coords'] = df_scat['UMAP-c1']
df_vec['y-coords'] = df_scat['UMAP-c2']
In [453]:
df_vec.head(15)
Out[453]:
Class Prediction Probability ID Name Smiles x-coords y-coords
0 Negative dataset 0 0.280 DB00006 Bivalirudin CC[C@H](C)[C@H](NC(=O)[C@H](CCC(O)=O)NC(=O)[C@... 9.247614 -0.817957
1 Negative dataset 0 0.190 DB00007 Leuprolide CCNC(=O)[C@@H]1CCCN1C(=O)[C@H](CCCNC(N)=N)NC(=... 8.665055 -0.751155
2 Negative dataset 0 0.184 DB00014 Goserelin CC(C)C[C@H](NC(=O)[C@@H](COC(C)(C)C)NC(=O)[C@H... 8.635137 -0.730543
3 Negative dataset 0 0.050 DB00035 Desmopressin NC(=O)CC[C@@H]1NC(=O)[C@H](CC2=CC=CC=C2)NC(=O)... 5.563773 -0.410755
4 Negative dataset 0 0.272 DB00050 Cetrorelix CC(C)C[C@H](NC(=O)[C@@H](CCCNC(N)=O)NC(=O)[C@H... 8.955009 -0.810468
5 Negative dataset 0 0.062 DB00080 Daptomycin CCCCCCCCCC(=O)N[C@@H](CC1=CNC2=C1C=CC=C2)C(=O)... 6.157721 -0.374207
6 Negative dataset 0 0.110 DB00091 Cyclosporine CC[C@@H]1NC(=O)[C@H]([C@H](O)[C@H](C)C\C=C\C)N... 5.178070 0.468352
7 Negative dataset 0 0.048 DB00104 Octreotide [H][C@]1(NC(=O)[C@H](CCCCN)NC(=O)[C@@H](CC2=CN... 5.494781 -0.333457
8 Negative dataset 0 0.252 DB00115 Cyanocobalamin C[C@H](CNC(=O)CC[C@]1(C)[C@@H](CC(N)=O)[C@H]2N... 4.341208 1.312017
9 Negative dataset 1 0.530 DB00158 Folic acid NC1=NC(=O)C2=NC(CNC3=CC=C(C=C3)C(=O)N[C@@H](CC... 3.415751 -1.807132
10 Negative dataset 0 0.342 DB00163 Vitamin E CC(C)CCC[C@@H](C)CCC[C@@H](C)CCC[C@]1(C)CCC2=C... -0.925311 -5.645232
11 Negative dataset 0 0.132 DB00199 Erythromycin CC[C@H]1OC(=O)[C@H](C)[C@@H](O[C@H]2C[C@@](C)(... 4.228536 4.943330
12 Negative dataset 0 0.252 DB00200 Hydroxocobalamin [N+]1=2[Co-3]345([N+]6=C7[C@H]([C@@](CC(=O)N)(... 4.329038 1.318275
13 Negative dataset 0 0.110 DB00207 Azithromycin CC[C@H]1OC(=O)[C@H](C)[C@@H](O[C@H]2C[C@@](C)(... 4.254687 4.930093
14 Negative dataset 0 0.208 DB00278 Argatroban C[C@@H]1CCN([C@H](C1)C(O)=O)C(=O)[C@H](CCCNC(N... -0.004274 -0.422700
In [454]:
df_vec.to_csv('probabilities_BA_prediction_curated.csv',index=False)

Case Two - BA curated

In [448]:
in_file_pos = './BA-curated-descriptors.csv'
in_file_neg = './non-BA-curated-descriptors.csv'
out_file = 'BA-curated-descriptors-vectors.csv'

test_X, test_df, test_classY, df_input = input_test_data(in_file_pos, in_file_neg, out_file)
test_X = test_X[:,8:]
Negative dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Postive dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Loading molecules.
RDKit ERROR: [17:20:26] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [17:20:26] SMILES Parse Error: syntax error for input: 'OC1=CC=CC(=C1)C-1=C2\CCC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC(O)=CC=C1)C1=CC(O)=CC=C1)\C1=CC(O)=CC=C1'
RDKit ERROR: [17:20:26] SMILES Parse Error: syntax error for input: 'OS(O)(O)C1=CC=C(C=C1)C-1=C2\C=CC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC=C(C=C1)S(O)(O)O)C1=CC=C(C=C1)S([O-])([O-])[O-])\C1=CC=C(C=C1)S(O)(O)[O-]'
RDKit ERROR: [17:20:26] Explicit valence for atom # 19 O, 3, is greater than permitted
Keeping only molecules that can be processed by RDKit.
Featurizing molecules.
(2336, 300)
In [449]:
clf = RandomForestClassifier(n_estimators=500, max_features='auto', random_state=RANDOM_STATE)

df_scat = plot_scatter(test_X, test_classY)

pred, prob, df_vec = run_classifier (clf, X, test_X, y, test_classY)


plt.subplots(figsize=(8, 8))
sns.violinplot(x="Class", y="Probability", data=df_vec, palette="muted")
plt.show()

plt.subplots(figsize=(8, 8))
sns.stripplot(x="Class", y="Probability", data=df_vec, palette="muted", jitter=0.10)
plt.show()
________________________________________________________________________________
Training: 
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=500,
                       n_jobs=None, oob_score=False, random_state=458,
                       verbose=0, warm_start=False)
train time: 3.099s
In [450]:
title = 'Confusion matrix with input classes'
disp = plot_confusion_matrix(clf, test_X, test_classY,
                             display_labels=[0,1],
                             cmap=plt.cm.Blues)
disp.ax_.set_title(title)

print()
print(disp.confusion_matrix)
plt.show()
[[1390   57]
 [ 388  501]]
In [451]:
df_scat.columns
Out[451]:
Index(['PCA-c1', 'PCA-c2', 'UMAP-c1', 'UMAP-c2', 'Class'], dtype='object')
In [452]:
df_vec['ID'] = df_input['ID']
df_vec['Name'] = df_input['Name']
df_vec['Smiles'] = df_input['Smiles']
df_vec['x-coords'] = df_scat['UMAP-c1']
df_vec['y-coords'] = df_scat['UMAP-c2']
In [453]:
df_vec.head(15)
Out[453]:
Class Prediction Probability ID Name Smiles x-coords y-coords
0 Negative dataset 0 0.280 DB00006 Bivalirudin CC[C@H](C)[C@H](NC(=O)[C@H](CCC(O)=O)NC(=O)[C@... 9.247614 -0.817957
1 Negative dataset 0 0.190 DB00007 Leuprolide CCNC(=O)[C@@H]1CCCN1C(=O)[C@H](CCCNC(N)=N)NC(=... 8.665055 -0.751155
2 Negative dataset 0 0.184 DB00014 Goserelin CC(C)C[C@H](NC(=O)[C@@H](COC(C)(C)C)NC(=O)[C@H... 8.635137 -0.730543
3 Negative dataset 0 0.050 DB00035 Desmopressin NC(=O)CC[C@@H]1NC(=O)[C@H](CC2=CC=CC=C2)NC(=O)... 5.563773 -0.410755
4 Negative dataset 0 0.272 DB00050 Cetrorelix CC(C)C[C@H](NC(=O)[C@@H](CCCNC(N)=O)NC(=O)[C@H... 8.955009 -0.810468
5 Negative dataset 0 0.062 DB00080 Daptomycin CCCCCCCCCC(=O)N[C@@H](CC1=CNC2=C1C=CC=C2)C(=O)... 6.157721 -0.374207
6 Negative dataset 0 0.110 DB00091 Cyclosporine CC[C@@H]1NC(=O)[C@H]([C@H](O)[C@H](C)C\C=C\C)N... 5.178070 0.468352
7 Negative dataset 0 0.048 DB00104 Octreotide [H][C@]1(NC(=O)[C@H](CCCCN)NC(=O)[C@@H](CC2=CN... 5.494781 -0.333457
8 Negative dataset 0 0.252 DB00115 Cyanocobalamin C[C@H](CNC(=O)CC[C@]1(C)[C@@H](CC(N)=O)[C@H]2N... 4.341208 1.312017
9 Negative dataset 1 0.530 DB00158 Folic acid NC1=NC(=O)C2=NC(CNC3=CC=C(C=C3)C(=O)N[C@@H](CC... 3.415751 -1.807132
10 Negative dataset 0 0.342 DB00163 Vitamin E CC(C)CCC[C@@H](C)CCC[C@@H](C)CCC[C@]1(C)CCC2=C... -0.925311 -5.645232
11 Negative dataset 0 0.132 DB00199 Erythromycin CC[C@H]1OC(=O)[C@H](C)[C@@H](O[C@H]2C[C@@](C)(... 4.228536 4.943330
12 Negative dataset 0 0.252 DB00200 Hydroxocobalamin [N+]1=2[Co-3]345([N+]6=C7[C@H]([C@@](CC(=O)N)(... 4.329038 1.318275
13 Negative dataset 0 0.110 DB00207 Azithromycin CC[C@H]1OC(=O)[C@H](C)[C@@H](O[C@H]2C[C@@](C)(... 4.254687 4.930093
14 Negative dataset 0 0.208 DB00278 Argatroban C[C@@H]1CCN([C@H](C1)C(O)=O)C(=O)[C@H](CCCNC(N... -0.004274 -0.422700
In [454]:
df_vec.to_csv('probabilities_BA_prediction_curated.csv',index=False)

Case Three - ChemBL phase 0 vs phase 1-4

In [387]:
in_file_pos = './BA-Chembl-4-phases-smiles-no-repeat.csv'
in_file_neg = './non-BA-Chembl-phase0-descriptors.csv'
out_file = 'BA-chembl-phases-descriptors-vectors.csv'

test_X, test_df, test_classY = input_test_data(in_file_pos, in_file_neg, out_file)
test_X = test_X[:,8:]
Negative dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Postive dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Loading molecules.
RDKit ERROR: [15:46:07] SMILES Parse Error: syntax error for input: 'Smiles'
Keeping only molecules that can be processed by RDKit.
Featurizing molecules.
(9432, 300)
In [388]:
clf = BaggingClassifier(n_estimators=50, max_samples=1.0, max_features=0.4, random_state=RANDOM_STATE)
clf = RandomForestClassifier(n_estimators=500, max_features='auto', random_state=RANDOM_STATE)

plot_scatter(test_X, test_classY)

pred, prob, df_vec = run_classifier (clf, X, test_X, y, test_classY)


plt.subplots(figsize=(8, 8))
sns.violinplot(x="Class", y="Probability", data=df_vec, palette="muted")
plt.show()

plt.subplots(figsize=(8, 8))
sns.stripplot(x="Class", y="Probability", data=df_vec, palette="muted", jitter=0.10)
plt.show()
________________________________________________________________________________
Training: 
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=500,
                       n_jobs=None, oob_score=False, random_state=458,
                       verbose=0, warm_start=False)
train time: 3.084s
In [389]:
title = 'Confusion matrix with input classes'
disp = plot_confusion_matrix(clf, test_X, test_classY,
                             display_labels=[0,1],
                             cmap=plt.cm.Blues)
disp.ax_.set_title(title)

print()
print(disp.confusion_matrix)
plt.show()
[[3590 2000]
 [2361 1481]]

Case Four - ChEMBL QED + RO-5 vs ChEMBL phase 1-4

In [390]:
in_file_pos = './BA-Chembl-4-phases-smiles-no-repeat.csv'
in_file_neg = './non-BA-qed-ro5-Chembl-smiles.csv'
out_file = 'non-BA-qed-ro5-descriptors-vectors.csv'

test_X, test_df, test_classY = input_test_data(in_file_pos, in_file_neg, out_file)
test_X = test_X[:,8:]
Negative dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Postive dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Loading molecules.
RDKit ERROR: [15:48:48] SMILES Parse Error: syntax error for input: 'Smiles'
Keeping only molecules that can be processed by RDKit.
Featurizing molecules.
(5839, 300)
In [391]:
clf = BaggingClassifier(n_estimators=50, max_samples=1.0, max_features=0.4, random_state=RANDOM_STATE)
clf = RandomForestClassifier(n_estimators=500, max_features='auto', random_state=RANDOM_STATE)

plot_scatter(test_X, test_classY)

pred, prob, df_vec = run_classifier (clf, X, test_X, y, test_classY)


plt.subplots(figsize=(8, 8))
sns.violinplot(x="Class", y="Probability", data=df_vec, palette="muted")
plt.show()

plt.subplots(figsize=(8, 8))
sns.stripplot(x="Class", y="Probability", data=df_vec, palette="muted", jitter=0.10)
plt.show()
________________________________________________________________________________
Training: 
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=500,
                       n_jobs=None, oob_score=False, random_state=458,
                       verbose=0, warm_start=False)
train time: 3.158s
In [392]:
title = 'Confusion matrix with input classes'
disp = plot_confusion_matrix(clf, test_X, test_classY,
                             display_labels=[0,1],
                             cmap=plt.cm.Blues)
disp.ax_.set_title(title)

print()
print(disp.confusion_matrix)
plt.show()
[[1863  134]
 [2361 1481]]

With both structure and descriptors

Case One - HSDB vs ChemBL 4-phases

In [393]:
in_file_pos = './BA-Chembl-4-phases-smiles-no-repeat.csv'
in_file_neg = './hsbd-smiles-no-repeat.csv'
out_file = 'BA-vectors.csv'

test_X, test_df, test_classY = input_test_data(in_file_pos, in_file_neg, out_file)
Negative dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Postive dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Loading molecules.
RDKit ERROR: [15:51:51] SMILES Parse Error: syntax error for input: 'Smiles'
RDKit ERROR: [15:51:51] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [15:51:51] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [15:51:51] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [15:51:51] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [15:51:51] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [15:51:51] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [15:51:52] Explicit valence for atom # 1 Cl, 7, is greater than permitted
RDKit ERROR: [15:51:52] SMILES Parse Error: syntax error for input: 'Smiles'
Keeping only molecules that can be processed by RDKit.
Featurizing molecules.
(7365, 300)
In [394]:
clf = BaggingClassifier(n_estimators=50, max_samples=1.0, max_features=0.4, random_state=RANDOM_STATE)
clf = RandomForestClassifier(n_estimators=500, max_features='auto', random_state=RANDOM_STATE)

plot_scatter(test_X, test_classY)

pred, prob, df_vec = run_classifier (clf, X_, test_X, y, test_classY)


plt.subplots(figsize=(8, 8))
sns.violinplot(x="Class", y="Probability", data=df_vec, palette="muted")
plt.show()

plt.subplots(figsize=(8, 8))
sns.stripplot(x="Class", y="Probability", data=df_vec, palette="muted", jitter=0.10)
plt.show()
________________________________________________________________________________
Training: 
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=500,
                       n_jobs=None, oob_score=False, random_state=458,
                       verbose=0, warm_start=False)
train time: 3.212s
In [395]:
title = 'Confusion matrix with input classes'
disp = plot_confusion_matrix(clf, test_X, test_classY,
                             display_labels=[0,1],
                             cmap=plt.cm.Blues)
disp.ax_.set_title(title)

print()
print(disp.confusion_matrix)
plt.show()
[[3424   99]
 [3801   41]]

Case Two - BA curated

In [396]:
in_file_pos = './BA-curated-descriptors.csv'
in_file_neg = './non-BA-curated-descriptors.csv'
out_file = 'BA-curated-descriptors-vectors.csv'

test_X, test_df, test_classY = input_test_data(in_file_pos, in_file_neg, out_file)
Negative dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Postive dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Loading molecules.
RDKit ERROR: [15:52:59] SMILES Parse Error: syntax error for input: 'nan'
RDKit ERROR: [15:52:59] SMILES Parse Error: syntax error for input: 'OC1=CC=CC(=C1)C-1=C2\CCC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC(O)=CC=C1)C1=CC(O)=CC=C1)\C1=CC(O)=CC=C1'
RDKit ERROR: [15:52:59] SMILES Parse Error: syntax error for input: 'OS(O)(O)C1=CC=C(C=C1)C-1=C2\C=CC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC=C(C=C1)S(O)(O)O)C1=CC=C(C=C1)S([O-])([O-])[O-])\C1=CC=C(C=C1)S(O)(O)[O-]'
RDKit ERROR: [15:52:59] Explicit valence for atom # 19 O, 3, is greater than permitted
Keeping only molecules that can be processed by RDKit.
Featurizing molecules.
(2336, 300)
In [397]:
clf = BaggingClassifier(n_estimators=50, max_samples=1.0, max_features=0.4, random_state=RANDOM_STATE)
clf = RandomForestClassifier(n_estimators=500, max_features='auto', random_state=RANDOM_STATE)

plot_scatter(test_X, test_classY)

pred, prob, df_vec = run_classifier (clf, X_, test_X, y, test_classY)


plt.subplots(figsize=(8, 8))
sns.violinplot(x="Class", y="Probability", data=df_vec, palette="muted")
plt.show()

plt.subplots(figsize=(8, 8))
sns.stripplot(x="Class", y="Probability", data=df_vec, palette="muted", jitter=0.10)
plt.show()
________________________________________________________________________________
Training: 
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=500,
                       n_jobs=None, oob_score=False, random_state=458,
                       verbose=0, warm_start=False)
train time: 3.253s
In [398]:
title = 'Confusion matrix with input classes'
disp = plot_confusion_matrix(clf, test_X, test_classY,
                             display_labels=[0,1],
                             cmap=plt.cm.Blues)
disp.ax_.set_title(title)

print()
print(disp.confusion_matrix)
plt.show()
[[1441    6]
 [ 880    9]]

Case Three - ChEMBL phase 0 vs phase 1-4

In [399]:
in_file_pos = './BA-Chembl-4-phases-smiles-no-repeat.csv'
in_file_neg = './non-BA-Chembl-phase0-descriptors.csv'
out_file = 'BA-chembl-phases-descriptors-vectors.csv'

test_X, test_df, test_classY = input_test_data(in_file_pos, in_file_neg, out_file)
Negative dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Postive dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Loading molecules.
RDKit ERROR: [15:54:09] SMILES Parse Error: syntax error for input: 'Smiles'
Keeping only molecules that can be processed by RDKit.
Featurizing molecules.
(9432, 300)
In [400]:
clf = BaggingClassifier(n_estimators=50, max_samples=1.0, max_features=0.4, random_state=RANDOM_STATE)
clf = RandomForestClassifier(n_estimators=500, max_features='auto', random_state=RANDOM_STATE)

plot_scatter(test_X, test_classY)

pred, prob, df_vec = run_classifier (clf, X_, test_X, y, test_classY)


plt.subplots(figsize=(8, 8))
sns.violinplot(x="Class", y="Probability", data=df_vec, palette="muted")
plt.show()

plt.subplots(figsize=(8, 8))
sns.stripplot(x="Class", y="Probability", data=df_vec, palette="muted", jitter=0.10)
plt.show()
________________________________________________________________________________
Training: 
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=500,
                       n_jobs=None, oob_score=False, random_state=458,
                       verbose=0, warm_start=False)
train time: 3.164s
In [401]:
title = 'Confusion matrix with input classes'
disp = plot_confusion_matrix(clf, test_X, test_classY,
                             display_labels=[0,1],
                             cmap=plt.cm.Blues)
disp.ax_.set_title(title)

print()
print(disp.confusion_matrix)
plt.show()
[[5583    7]
 [3801   41]]

Case Four - ChEMBL QED + RO-5 vs phase 1-4

In [402]:
in_file_pos = './BA-Chembl-4-phases-smiles-no-repeat.csv'
in_file_neg = './non-BA-qed-ro5-Chembl-smiles.csv'
out_file = 'non-BA-qed-ro5-descriptors-vectors.csv'

test_X, test_df, test_classY = input_test_data(in_file_pos, in_file_neg, out_file)
Negative dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Postive dataframe columns
Index(['ID', 'Name', 'Smiles', 'Class'], dtype='object')
Loading molecules.
RDKit ERROR: [15:55:37] SMILES Parse Error: syntax error for input: 'Smiles'
Keeping only molecules that can be processed by RDKit.
Featurizing molecules.
(5839, 300)
In [403]:
clf = BaggingClassifier(n_estimators=50, max_samples=1.0, max_features=0.4, random_state=RANDOM_STATE)

clf = RandomForestClassifier(n_estimators=500, max_features='auto', random_state=RANDOM_STATE)

plot_scatter(test_X, test_classY)

pred, prob, df_vec = run_classifier (clf, X_, test_X, y, test_classY)


plt.subplots(figsize=(8, 8))
sns.violinplot(x="Class", y="Probability", data=df_vec, palette="muted")
plt.show()

plt.subplots(figsize=(8, 8))
sns.stripplot(x="Class", y="Probability", data=df_vec, palette="muted", jitter=0.10)
plt.show()
________________________________________________________________________________
Training: 
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=500,
                       n_jobs=None, oob_score=False, random_state=458,
                       verbose=0, warm_start=False)
train time: 3.274s
In [404]:
title = 'Confusion matrix with input classes'
disp = plot_confusion_matrix(clf, test_X, test_classY,
                             display_labels=[0,1],
                             cmap=plt.cm.Blues)
disp.ax_.set_title(title)

print()
print(disp.confusion_matrix)
plt.show()
[[1990    7]
 [3801   41]]

Tensorflow binary classification

In [405]:
RANDOM_STATE = 458
TEST_SIZE = 0.2

X_df_ = shuffle(X_df, random_state=RANDOM_STATE)

train, test = train_test_split(X_df_, test_size=TEST_SIZE, random_state=RANDOM_STATE)
test, val = train_test_split(test, test_size=0.5, random_state=RANDOM_STATE)
In [406]:
print(len(train), 'train examples')
print(len(val), 'validation examples')
print(len(test), 'test examples')
796 train examples
100 validation examples
99 test examples
In [407]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow import feature_column
In [408]:
# A utility method to create a tf.data dataset from a Pandas Dataframe
def df_to_dataset(dataframe, shuffle=True, batch_size=32):
    dataframe = dataframe.copy()
    labels = dataframe.pop('Class')
    dataframe = dataframe.drop(['PercentF'], axis=1)
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(dataframe))
    ds = ds.batch(batch_size)
    return ds
In [409]:
batch_size = 5 # A small batch sized is used for demonstration purposes
train_ds = df_to_dataset(train, batch_size=batch_size)
val_ds = df_to_dataset(val, shuffle=False, batch_size=batch_size)
test_ds = df_to_dataset(test, shuffle=False, batch_size=batch_size)
In [410]:
example_batch = next(iter(train_ds))[0]

#for feature_batch, label_batch in train_ds.take(1):
#    print('Every feature:', list(feature_batch.keys()))
#    print('A batch of ages:', feature_batch['MW'])
#    print('A batch of targets:', label_batch )
In [411]:
def demo(feature_column):
    feature_layer = layers.DenseFeatures(feature_column)
    print(feature_layer(example_batch).numpy())
In [412]:
#MW = feature_column.numeric_column("AROM")
#demo(MW)
In [413]:
feature_columns = []

# numeric cols
for header in ['MW', 'ALOGP', 'PSA']:
    feature_columns.append(feature_column.numeric_column(header))

for header in ['HBA', 'HBD', 'ROTB', 'ALERTS']:
    feature_columns.append(feature_column.numeric_column(header))

feature_columns.append(feature_column.numeric_column('AROM'))

for header in X_df.columns[10:].tolist():
    feature_columns.append(feature_column.numeric_column(header))


feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
In [414]:
batch_size = 32
train_ds = df_to_dataset(train, batch_size=batch_size)
val_ds = df_to_dataset(val, shuffle=False, batch_size=batch_size)
test_ds = df_to_dataset(test, shuffle=False, batch_size=batch_size)
In [415]:
model = tf.keras.Sequential([
    feature_layer,
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(1)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_ds,
          validation_data=val_ds,
          epochs=100)
Train for 25 steps, validate for 4 steps
Epoch 1/100
25/25 [==============================] - 4s 145ms/step - loss: 1.5553 - accuracy: 0.5264 - val_loss: 0.8365 - val_accuracy: 0.4300
Epoch 2/100
25/25 [==============================] - 1s 23ms/step - loss: 0.7551 - accuracy: 0.5088 - val_loss: 0.6791 - val_accuracy: 0.4700
Epoch 3/100
25/25 [==============================] - 1s 24ms/step - loss: 0.6635 - accuracy: 0.5540 - val_loss: 0.6446 - val_accuracy: 0.4700
Epoch 4/100
25/25 [==============================] - 1s 24ms/step - loss: 0.6634 - accuracy: 0.5779 - val_loss: 0.6664 - val_accuracy: 0.4700
Epoch 5/100
25/25 [==============================] - 1s 23ms/step - loss: 0.6417 - accuracy: 0.5666 - val_loss: 0.6654 - val_accuracy: 0.4700
Epoch 6/100
25/25 [==============================] - 1s 24ms/step - loss: 0.6518 - accuracy: 0.6068 - val_loss: 0.6557 - val_accuracy: 0.4700
Epoch 7/100
25/25 [==============================] - 1s 24ms/step - loss: 0.6922 - accuracy: 0.5590 - val_loss: 0.8798 - val_accuracy: 0.5900
Epoch 8/100
25/25 [==============================] - 1s 24ms/step - loss: 0.6726 - accuracy: 0.5590 - val_loss: 0.6345 - val_accuracy: 0.4800
Epoch 9/100
25/25 [==============================] - 1s 23ms/step - loss: 0.6695 - accuracy: 0.5980 - val_loss: 0.6078 - val_accuracy: 0.5200
Epoch 10/100
25/25 [==============================] - 1s 25ms/step - loss: 0.6375 - accuracy: 0.5879 - val_loss: 0.6135 - val_accuracy: 0.5300
Epoch 11/100
25/25 [==============================] - 1s 23ms/step - loss: 0.6019 - accuracy: 0.6143 - val_loss: 0.6249 - val_accuracy: 0.5800
Epoch 12/100
25/25 [==============================] - 1s 24ms/step - loss: 0.6134 - accuracy: 0.6319 - val_loss: 0.6412 - val_accuracy: 0.4900
Epoch 13/100
25/25 [==============================] - 1s 26ms/step - loss: 0.6183 - accuracy: 0.6131 - val_loss: 0.6715 - val_accuracy: 0.6100
Epoch 14/100
25/25 [==============================] - 1s 25ms/step - loss: 0.6171 - accuracy: 0.6432 - val_loss: 0.6533 - val_accuracy: 0.5300
Epoch 15/100
25/25 [==============================] - 1s 25ms/step - loss: 0.6436 - accuracy: 0.5917 - val_loss: 0.6194 - val_accuracy: 0.5900
Epoch 16/100
25/25 [==============================] - 1s 25ms/step - loss: 0.5834 - accuracy: 0.6495 - val_loss: 0.6169 - val_accuracy: 0.6000
Epoch 17/100
25/25 [==============================] - 1s 25ms/step - loss: 0.5841 - accuracy: 0.6671 - val_loss: 0.6120 - val_accuracy: 0.5700
Epoch 18/100
25/25 [==============================] - 1s 24ms/step - loss: 0.5683 - accuracy: 0.6922 - val_loss: 0.6158 - val_accuracy: 0.5900
Epoch 19/100
25/25 [==============================] - 1s 23ms/step - loss: 0.5737 - accuracy: 0.6746 - val_loss: 0.5997 - val_accuracy: 0.6300
Epoch 20/100
25/25 [==============================] - 1s 23ms/step - loss: 0.5548 - accuracy: 0.6709 - val_loss: 0.5811 - val_accuracy: 0.6300
Epoch 21/100
25/25 [==============================] - 1s 24ms/step - loss: 0.5569 - accuracy: 0.6809 - val_loss: 0.6116 - val_accuracy: 0.5900
Epoch 22/100
25/25 [==============================] - 1s 23ms/step - loss: 0.5793 - accuracy: 0.6784 - val_loss: 0.6960 - val_accuracy: 0.6500
Epoch 23/100
25/25 [==============================] - 1s 24ms/step - loss: 0.5651 - accuracy: 0.7035 - val_loss: 0.6502 - val_accuracy: 0.5100
Epoch 24/100
25/25 [==============================] - 1s 24ms/step - loss: 0.5540 - accuracy: 0.6784 - val_loss: 0.5751 - val_accuracy: 0.6500
Epoch 25/100
25/25 [==============================] - 1s 23ms/step - loss: 0.5549 - accuracy: 0.6960 - val_loss: 0.5472 - val_accuracy: 0.6200
Epoch 26/100
25/25 [==============================] - 1s 27ms/step - loss: 0.5172 - accuracy: 0.7010 - val_loss: 0.5675 - val_accuracy: 0.5800
Epoch 27/100
25/25 [==============================] - 1s 24ms/step - loss: 0.4914 - accuracy: 0.7286 - val_loss: 0.6290 - val_accuracy: 0.7000
Epoch 28/100
25/25 [==============================] - 1s 23ms/step - loss: 0.5054 - accuracy: 0.7186 - val_loss: 0.5915 - val_accuracy: 0.6100
Epoch 29/100
25/25 [==============================] - 1s 24ms/step - loss: 0.5170 - accuracy: 0.7362 - val_loss: 0.5801 - val_accuracy: 0.6200
Epoch 30/100
25/25 [==============================] - 1s 24ms/step - loss: 0.5081 - accuracy: 0.7236 - val_loss: 0.5822 - val_accuracy: 0.6700
Epoch 31/100
25/25 [==============================] - 1s 23ms/step - loss: 0.5046 - accuracy: 0.7010 - val_loss: 0.6439 - val_accuracy: 0.6600
Epoch 32/100
25/25 [==============================] - 1s 25ms/step - loss: 0.4855 - accuracy: 0.7538 - val_loss: 0.5528 - val_accuracy: 0.5900
Epoch 33/100
25/25 [==============================] - 1s 24ms/step - loss: 0.5265 - accuracy: 0.6922 - val_loss: 0.5551 - val_accuracy: 0.6000
Epoch 34/100
25/25 [==============================] - 1s 24ms/step - loss: 0.5086 - accuracy: 0.7161 - val_loss: 0.6523 - val_accuracy: 0.6700
Epoch 35/100
25/25 [==============================] - 1s 23ms/step - loss: 0.4808 - accuracy: 0.7450 - val_loss: 1.4297 - val_accuracy: 0.5700
Epoch 36/100
25/25 [==============================] - 1s 24ms/step - loss: 0.5247 - accuracy: 0.7073 - val_loss: 0.5981 - val_accuracy: 0.6400
Epoch 37/100
25/25 [==============================] - 1s 24ms/step - loss: 0.5156 - accuracy: 0.7073 - val_loss: 0.6164 - val_accuracy: 0.6600
Epoch 38/100
25/25 [==============================] - 1s 25ms/step - loss: 0.4833 - accuracy: 0.7462 - val_loss: 0.5804 - val_accuracy: 0.6400
Epoch 39/100
25/25 [==============================] - 1s 26ms/step - loss: 0.4722 - accuracy: 0.7349 - val_loss: 0.6087 - val_accuracy: 0.6700
Epoch 40/100
25/25 [==============================] - 1s 27ms/step - loss: 0.4761 - accuracy: 0.7500 - val_loss: 0.6080 - val_accuracy: 0.6500
Epoch 41/100
25/25 [==============================] - 1s 24ms/step - loss: 0.4451 - accuracy: 0.7588 - val_loss: 0.6212 - val_accuracy: 0.6700
Epoch 42/100
25/25 [==============================] - 1s 24ms/step - loss: 0.4685 - accuracy: 0.7487 - val_loss: 0.7213 - val_accuracy: 0.5200
Epoch 43/100
25/25 [==============================] - 1s 23ms/step - loss: 0.4856 - accuracy: 0.7525 - val_loss: 0.6876 - val_accuracy: 0.6800
Epoch 44/100
25/25 [==============================] - 1s 24ms/step - loss: 0.4462 - accuracy: 0.7701 - val_loss: 0.5940 - val_accuracy: 0.6300
Epoch 45/100
25/25 [==============================] - 1s 26ms/step - loss: 0.4165 - accuracy: 0.7802 - val_loss: 0.5262 - val_accuracy: 0.6200
Epoch 46/100
25/25 [==============================] - 1s 25ms/step - loss: 0.4366 - accuracy: 0.7688 - val_loss: 0.6134 - val_accuracy: 0.6700
Epoch 47/100
25/25 [==============================] - 1s 24ms/step - loss: 0.4054 - accuracy: 0.7915 - val_loss: 0.6709 - val_accuracy: 0.6700
Epoch 48/100
25/25 [==============================] - 1s 23ms/step - loss: 0.4338 - accuracy: 0.7638 - val_loss: 0.5380 - val_accuracy: 0.6200
Epoch 49/100
25/25 [==============================] - 1s 24ms/step - loss: 0.4061 - accuracy: 0.7764 - val_loss: 0.5777 - val_accuracy: 0.6400
Epoch 50/100
25/25 [==============================] - 1s 24ms/step - loss: 0.3789 - accuracy: 0.7965 - val_loss: 0.6875 - val_accuracy: 0.6700
Epoch 51/100
25/25 [==============================] - 1s 24ms/step - loss: 0.3920 - accuracy: 0.8028 - val_loss: 0.6341 - val_accuracy: 0.6900
Epoch 52/100
25/25 [==============================] - 1s 24ms/step - loss: 0.3701 - accuracy: 0.8128 - val_loss: 0.5770 - val_accuracy: 0.7100
Epoch 53/100
25/25 [==============================] - 1s 24ms/step - loss: 0.3791 - accuracy: 0.8040 - val_loss: 0.7139 - val_accuracy: 0.6900
Epoch 54/100
25/25 [==============================] - 1s 23ms/step - loss: 0.3704 - accuracy: 0.8103 - val_loss: 0.8032 - val_accuracy: 0.6900
Epoch 55/100
25/25 [==============================] - 1s 23ms/step - loss: 0.3408 - accuracy: 0.8229 - val_loss: 0.5735 - val_accuracy: 0.6600
Epoch 56/100
25/25 [==============================] - 1s 24ms/step - loss: 0.3605 - accuracy: 0.8090 - val_loss: 0.6108 - val_accuracy: 0.6900
Epoch 57/100
25/25 [==============================] - 1s 24ms/step - loss: 0.3488 - accuracy: 0.8166 - val_loss: 0.6426 - val_accuracy: 0.6700
Epoch 58/100
25/25 [==============================] - 1s 23ms/step - loss: 0.4030 - accuracy: 0.8103 - val_loss: 0.8083 - val_accuracy: 0.6700
Epoch 59/100
25/25 [==============================] - 1s 23ms/step - loss: 0.4313 - accuracy: 0.7575 - val_loss: 0.7012 - val_accuracy: 0.6700
Epoch 60/100
25/25 [==============================] - 1s 24ms/step - loss: 0.3717 - accuracy: 0.8078 - val_loss: 0.6831 - val_accuracy: 0.6900
Epoch 61/100
25/25 [==============================] - 1s 23ms/step - loss: 0.3217 - accuracy: 0.8354 - val_loss: 0.5950 - val_accuracy: 0.7100
Epoch 62/100
25/25 [==============================] - 1s 23ms/step - loss: 0.3126 - accuracy: 0.8543 - val_loss: 0.6104 - val_accuracy: 0.7000
Epoch 63/100
25/25 [==============================] - 1s 24ms/step - loss: 0.3280 - accuracy: 0.8317 - val_loss: 0.6686 - val_accuracy: 0.6200
Epoch 64/100
25/25 [==============================] - 1s 25ms/step - loss: 0.3407 - accuracy: 0.8254 - val_loss: 0.6179 - val_accuracy: 0.6600
Epoch 65/100
25/25 [==============================] - 1s 23ms/step - loss: 0.3611 - accuracy: 0.8266 - val_loss: 0.6824 - val_accuracy: 0.6500
Epoch 66/100
25/25 [==============================] - 1s 25ms/step - loss: 0.3228 - accuracy: 0.8367 - val_loss: 0.7823 - val_accuracy: 0.5900
Epoch 67/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2929 - accuracy: 0.8505 - val_loss: 1.1227 - val_accuracy: 0.6600
Epoch 68/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2684 - accuracy: 0.8681 - val_loss: 0.7451 - val_accuracy: 0.6900
Epoch 69/100
25/25 [==============================] - 1s 26ms/step - loss: 0.3072 - accuracy: 0.8417 - val_loss: 0.7443 - val_accuracy: 0.6800
Epoch 70/100
25/25 [==============================] - 1s 23ms/step - loss: 0.3432 - accuracy: 0.8291 - val_loss: 0.6999 - val_accuracy: 0.6800
Epoch 71/100
25/25 [==============================] - 1s 24ms/step - loss: 0.3811 - accuracy: 0.7902 - val_loss: 0.6063 - val_accuracy: 0.6700
Epoch 72/100
25/25 [==============================] - 1s 24ms/step - loss: 0.3058 - accuracy: 0.8631 - val_loss: 0.8029 - val_accuracy: 0.6400
Epoch 73/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2876 - accuracy: 0.8731 - val_loss: 0.7200 - val_accuracy: 0.7100
Epoch 74/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2979 - accuracy: 0.8668 - val_loss: 0.7851 - val_accuracy: 0.6900
Epoch 75/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2918 - accuracy: 0.8631 - val_loss: 0.7981 - val_accuracy: 0.6400
Epoch 76/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2585 - accuracy: 0.8832 - val_loss: 1.0328 - val_accuracy: 0.6700
Epoch 77/100
25/25 [==============================] - 1s 26ms/step - loss: 0.2863 - accuracy: 0.8668 - val_loss: 0.8013 - val_accuracy: 0.6500
Epoch 78/100
25/25 [==============================] - 1s 24ms/step - loss: 0.3050 - accuracy: 0.8492 - val_loss: 0.9597 - val_accuracy: 0.6900
Epoch 79/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2367 - accuracy: 0.8894 - val_loss: 0.8489 - val_accuracy: 0.6900
Epoch 80/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2407 - accuracy: 0.8844 - val_loss: 1.1496 - val_accuracy: 0.6800
Epoch 81/100
25/25 [==============================] - 1s 26ms/step - loss: 0.2339 - accuracy: 0.8945 - val_loss: 1.0988 - val_accuracy: 0.6800
Epoch 82/100
25/25 [==============================] - 1s 25ms/step - loss: 0.2398 - accuracy: 0.8668 - val_loss: 0.8612 - val_accuracy: 0.6300
Epoch 83/100
25/25 [==============================] - 1s 24ms/step - loss: 0.3397 - accuracy: 0.8204 - val_loss: 0.8125 - val_accuracy: 0.6600
Epoch 84/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2763 - accuracy: 0.8618 - val_loss: 0.7859 - val_accuracy: 0.6800
Epoch 85/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2719 - accuracy: 0.8631 - val_loss: 0.8729 - val_accuracy: 0.7000
Epoch 86/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2344 - accuracy: 0.8882 - val_loss: 0.9691 - val_accuracy: 0.7100
Epoch 87/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2285 - accuracy: 0.8907 - val_loss: 0.9615 - val_accuracy: 0.6900
Epoch 88/100
25/25 [==============================] - 1s 25ms/step - loss: 0.2248 - accuracy: 0.8945 - val_loss: 0.8073 - val_accuracy: 0.6800
Epoch 89/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2672 - accuracy: 0.8781 - val_loss: 0.7299 - val_accuracy: 0.6700
Epoch 90/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2453 - accuracy: 0.8706 - val_loss: 1.1302 - val_accuracy: 0.6700
Epoch 91/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2821 - accuracy: 0.8681 - val_loss: 0.9174 - val_accuracy: 0.7000
Epoch 92/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2278 - accuracy: 0.9008 - val_loss: 1.0735 - val_accuracy: 0.6800
Epoch 93/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2144 - accuracy: 0.8957 - val_loss: 0.9469 - val_accuracy: 0.7100
Epoch 94/100
25/25 [==============================] - 1s 23ms/step - loss: 0.3370 - accuracy: 0.8329 - val_loss: 0.8210 - val_accuracy: 0.6500
Epoch 95/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2950 - accuracy: 0.8505 - val_loss: 0.8711 - val_accuracy: 0.6900
Epoch 96/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2431 - accuracy: 0.8832 - val_loss: 0.8675 - val_accuracy: 0.6900
Epoch 97/100
25/25 [==============================] - 1s 25ms/step - loss: 0.1818 - accuracy: 0.9234 - val_loss: 0.9817 - val_accuracy: 0.6700
Epoch 98/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2066 - accuracy: 0.9020 - val_loss: 0.9987 - val_accuracy: 0.6600
Epoch 99/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2166 - accuracy: 0.8982 - val_loss: 1.1112 - val_accuracy: 0.6600
Epoch 100/100
25/25 [==============================] - 1s 24ms/step - loss: 0.2104 - accuracy: 0.9020 - val_loss: 0.9918 - val_accuracy: 0.7200
Out[415]:
<tensorflow.python.keras.callbacks.History at 0x7f003a4b9048>
In [416]:
loss, accuracy = model.evaluate(test_ds)
print("Accuracy", accuracy)
4/4 [==============================] - 0s 6ms/step - loss: 1.0931 - accuracy: 0.6566
Accuracy 0.65656567
In [419]:
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(300,)),
    keras.layers.Dense(200, activation=tf.nn.relu),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(56, activation=tf.nn.relu),
    keras.layers.Dense(1, activation=tf.nn.sigmoid),
])
In [421]:
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

model.fit(X_train, y_train, epochs=200, batch_size=30)
test_loss, test_acc = model.evaluate(X_test, y_test)
print('Test accuracy:', test_acc)
Train on 796 samples
Epoch 1/200
796/796 [==============================] - 0s 311us/sample - loss: 0.2072 - accuracy: 0.9221
Epoch 2/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1697 - accuracy: 0.9271
Epoch 3/200
796/796 [==============================] - 0s 44us/sample - loss: 0.1655 - accuracy: 0.9347
Epoch 4/200
796/796 [==============================] - 0s 36us/sample - loss: 0.1609 - accuracy: 0.9397
Epoch 5/200
796/796 [==============================] - 0s 34us/sample - loss: 0.1907 - accuracy: 0.9209
Epoch 6/200
796/796 [==============================] - 0s 41us/sample - loss: 0.1449 - accuracy: 0.9472
Epoch 7/200
796/796 [==============================] - 0s 36us/sample - loss: 0.1766 - accuracy: 0.9183
Epoch 8/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1609 - accuracy: 0.9309
Epoch 9/200
796/796 [==============================] - 0s 40us/sample - loss: 0.1569 - accuracy: 0.9322
Epoch 10/200
796/796 [==============================] - 0s 36us/sample - loss: 0.1670 - accuracy: 0.9322
Epoch 11/200
796/796 [==============================] - 0s 34us/sample - loss: 0.1549 - accuracy: 0.9384
Epoch 12/200
796/796 [==============================] - 0s 39us/sample - loss: 0.1318 - accuracy: 0.9472
Epoch 13/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1088 - accuracy: 0.9523
Epoch 14/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1449 - accuracy: 0.9472
Epoch 15/200
796/796 [==============================] - 0s 39us/sample - loss: 0.1191 - accuracy: 0.9523
Epoch 16/200
796/796 [==============================] - 0s 34us/sample - loss: 0.1067 - accuracy: 0.9472
Epoch 17/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0960 - accuracy: 0.9598
Epoch 18/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0879 - accuracy: 0.9661
Epoch 19/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1014 - accuracy: 0.9598
Epoch 20/200
796/796 [==============================] - 0s 32us/sample - loss: 0.1124 - accuracy: 0.9573
Epoch 21/200
796/796 [==============================] - 0s 35us/sample - loss: 0.1337 - accuracy: 0.9435
Epoch 22/200
796/796 [==============================] - 0s 38us/sample - loss: 0.2238 - accuracy: 0.9083
Epoch 23/200
796/796 [==============================] - 0s 33us/sample - loss: 0.2705 - accuracy: 0.8995
Epoch 24/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1817 - accuracy: 0.9397
Epoch 25/200
796/796 [==============================] - 0s 40us/sample - loss: 0.1188 - accuracy: 0.9422
Epoch 26/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1108 - accuracy: 0.9535
Epoch 27/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1287 - accuracy: 0.9535
Epoch 28/200
796/796 [==============================] - 0s 39us/sample - loss: 0.0943 - accuracy: 0.9686
Epoch 29/200
796/796 [==============================] - 0s 32us/sample - loss: 0.1103 - accuracy: 0.9447
Epoch 30/200
796/796 [==============================] - 0s 34us/sample - loss: 0.1160 - accuracy: 0.9497
Epoch 31/200
796/796 [==============================] - 0s 40us/sample - loss: 0.1412 - accuracy: 0.9460
Epoch 32/200
796/796 [==============================] - 0s 34us/sample - loss: 0.1432 - accuracy: 0.9384
Epoch 33/200
796/796 [==============================] - 0s 34us/sample - loss: 0.1206 - accuracy: 0.9472
Epoch 34/200
796/796 [==============================] - 0s 38us/sample - loss: 0.1293 - accuracy: 0.9472
Epoch 35/200
796/796 [==============================] - 0s 35us/sample - loss: 0.1062 - accuracy: 0.9535
Epoch 36/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0635 - accuracy: 0.9774
Epoch 37/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0544 - accuracy: 0.9799
Epoch 38/200
796/796 [==============================] - 0s 43us/sample - loss: 0.0570 - accuracy: 0.9786
Epoch 39/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0598 - accuracy: 0.9736
Epoch 40/200
796/796 [==============================] - 0s 35us/sample - loss: 0.0462 - accuracy: 0.9874
Epoch 41/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0482 - accuracy: 0.9824
Epoch 42/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0778 - accuracy: 0.9648
Epoch 43/200
796/796 [==============================] - 0s 38us/sample - loss: 0.1165 - accuracy: 0.9548
Epoch 44/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0923 - accuracy: 0.9611
Epoch 45/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0620 - accuracy: 0.9774
Epoch 46/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0496 - accuracy: 0.9824
Epoch 47/200
796/796 [==============================] - 0s 40us/sample - loss: 0.0577 - accuracy: 0.9761
Epoch 48/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0431 - accuracy: 0.9824
Epoch 49/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0403 - accuracy: 0.9862
Epoch 50/200
796/796 [==============================] - 0s 40us/sample - loss: 0.0368 - accuracy: 0.9887
Epoch 51/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0454 - accuracy: 0.9824
Epoch 52/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0425 - accuracy: 0.9862
Epoch 53/200
796/796 [==============================] - 0s 39us/sample - loss: 0.0519 - accuracy: 0.9812
Epoch 54/200
796/796 [==============================] - 0s 35us/sample - loss: 0.1341 - accuracy: 0.9460
Epoch 55/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1589 - accuracy: 0.9359
Epoch 56/200
796/796 [==============================] - 0s 34us/sample - loss: 0.1808 - accuracy: 0.9372
Epoch 57/200
796/796 [==============================] - 0s 40us/sample - loss: 0.4975 - accuracy: 0.8593
Epoch 58/200
796/796 [==============================] - 0s 33us/sample - loss: 0.3308 - accuracy: 0.8706
Epoch 59/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1349 - accuracy: 0.9485
Epoch 60/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0821 - accuracy: 0.9648
Epoch 61/200
796/796 [==============================] - 0s 37us/sample - loss: 0.0535 - accuracy: 0.9837
Epoch 62/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0441 - accuracy: 0.9849
Epoch 63/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0403 - accuracy: 0.9887
Epoch 64/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0329 - accuracy: 0.9874
Epoch 65/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0359 - accuracy: 0.9874
Epoch 66/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0388 - accuracy: 0.9862
Epoch 67/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0473 - accuracy: 0.9799
Epoch 68/200
796/796 [==============================] - 0s 43us/sample - loss: 0.0427 - accuracy: 0.9849
Epoch 69/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0363 - accuracy: 0.9862
Epoch 70/200
796/796 [==============================] - 0s 35us/sample - loss: 0.0301 - accuracy: 0.9887
Epoch 71/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0439 - accuracy: 0.9824
Epoch 72/200
796/796 [==============================] - 0s 48us/sample - loss: 0.0289 - accuracy: 0.9899
Epoch 73/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0374 - accuracy: 0.9862
Epoch 74/200
796/796 [==============================] - 0s 40us/sample - loss: 0.0471 - accuracy: 0.9862
Epoch 75/200
796/796 [==============================] - 0s 41us/sample - loss: 0.0404 - accuracy: 0.9849
Epoch 76/200
796/796 [==============================] - 0s 35us/sample - loss: 0.0286 - accuracy: 0.9874
Epoch 77/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0248 - accuracy: 0.9925
Epoch 78/200
796/796 [==============================] - 0s 39us/sample - loss: 0.0236 - accuracy: 0.9899
Epoch 79/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0281 - accuracy: 0.9899
Epoch 80/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0193 - accuracy: 0.9950
Epoch 81/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0169 - accuracy: 0.9937
Epoch 82/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0138 - accuracy: 0.9975
Epoch 83/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0270 - accuracy: 0.9937
Epoch 84/200
796/796 [==============================] - 0s 32us/sample - loss: 0.0161 - accuracy: 0.9937
Epoch 85/200
796/796 [==============================] - 0s 39us/sample - loss: 0.0348 - accuracy: 0.9937
Epoch 86/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0126 - accuracy: 0.9987
Epoch 87/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0122 - accuracy: 0.9975
Epoch 88/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0182 - accuracy: 0.9950
Epoch 89/200
796/796 [==============================] - 0s 35us/sample - loss: 0.0209 - accuracy: 0.9912
Epoch 90/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0556 - accuracy: 0.9761
Epoch 91/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0977 - accuracy: 0.9648
Epoch 92/200
796/796 [==============================] - 0s 40us/sample - loss: 0.1957 - accuracy: 0.9359
Epoch 93/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1374 - accuracy: 0.9347
Epoch 94/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1916 - accuracy: 0.9284
Epoch 95/200
796/796 [==============================] - 0s 38us/sample - loss: 0.1474 - accuracy: 0.9447
Epoch 96/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0918 - accuracy: 0.9611
Epoch 97/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0356 - accuracy: 0.9899
Epoch 98/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0218 - accuracy: 0.9950
Epoch 99/200
796/796 [==============================] - 0s 35us/sample - loss: 0.0165 - accuracy: 0.9987
Epoch 100/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0139 - accuracy: 0.9962
Epoch 101/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0120 - accuracy: 0.9987
Epoch 102/200
796/796 [==============================] - 0s 35us/sample - loss: 0.0094 - accuracy: 1.0000
Epoch 103/200
796/796 [==============================] - 0s 39us/sample - loss: 0.0096 - accuracy: 0.9962
Epoch 104/200
796/796 [==============================] - 0s 42us/sample - loss: 0.0076 - accuracy: 0.9987
Epoch 105/200
796/796 [==============================] - 0s 35us/sample - loss: 0.0082 - accuracy: 0.9987
Epoch 106/200
796/796 [==============================] - 0s 46us/sample - loss: 0.0138 - accuracy: 0.9962
Epoch 107/200
796/796 [==============================] - 0s 47us/sample - loss: 0.0087 - accuracy: 0.9987
Epoch 108/200
796/796 [==============================] - 0s 37us/sample - loss: 0.0118 - accuracy: 0.9950
Epoch 109/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0152 - accuracy: 0.9937
Epoch 110/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0143 - accuracy: 0.9925
Epoch 111/200
796/796 [==============================] - 0s 39us/sample - loss: 0.0188 - accuracy: 0.9937
Epoch 112/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0186 - accuracy: 0.9950
Epoch 113/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0480 - accuracy: 0.9837
Epoch 114/200
796/796 [==============================] - 0s 42us/sample - loss: 0.1506 - accuracy: 0.9573
Epoch 115/200
796/796 [==============================] - 0s 35us/sample - loss: 0.4749 - accuracy: 0.8719
Epoch 116/200
796/796 [==============================] - 0s 32us/sample - loss: 0.3165 - accuracy: 0.8706
Epoch 117/200
796/796 [==============================] - 0s 38us/sample - loss: 0.1355 - accuracy: 0.9472
Epoch 118/200
796/796 [==============================] - 0s 41us/sample - loss: 0.0705 - accuracy: 0.9786
Epoch 119/200
796/796 [==============================] - 0s 35us/sample - loss: 0.0396 - accuracy: 0.9899
Epoch 120/200
796/796 [==============================] - 0s 32us/sample - loss: 0.0231 - accuracy: 0.9937
Epoch 121/200
796/796 [==============================] - 0s 37us/sample - loss: 0.0157 - accuracy: 1.0000
Epoch 122/200
796/796 [==============================] - 0s 37us/sample - loss: 0.0132 - accuracy: 0.9987
Epoch 123/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0120 - accuracy: 0.9987
Epoch 124/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0116 - accuracy: 0.9987
Epoch 125/200
796/796 [==============================] - 0s 41us/sample - loss: 0.0217 - accuracy: 0.9912
Epoch 126/200
796/796 [==============================] - 0s 32us/sample - loss: 0.0532 - accuracy: 0.9812
Epoch 127/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1104 - accuracy: 0.9598
Epoch 128/200
796/796 [==============================] - 0s 40us/sample - loss: 0.1987 - accuracy: 0.9347
Epoch 129/200
796/796 [==============================] - 0s 34us/sample - loss: 0.1579 - accuracy: 0.9447
Epoch 130/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0871 - accuracy: 0.9686
Epoch 131/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0550 - accuracy: 0.9849
Epoch 132/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0343 - accuracy: 0.9887
Epoch 133/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0213 - accuracy: 0.9962
Epoch 134/200
796/796 [==============================] - 0s 37us/sample - loss: 0.0169 - accuracy: 0.9975
Epoch 135/200
796/796 [==============================] - 0s 41us/sample - loss: 0.0133 - accuracy: 0.9987
Epoch 136/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0110 - accuracy: 0.9987
Epoch 137/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0084 - accuracy: 1.0000
Epoch 138/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0073 - accuracy: 0.9987
Epoch 139/200
796/796 [==============================] - 0s 37us/sample - loss: 0.0088 - accuracy: 0.9987
Epoch 140/200
796/796 [==============================] - 0s 45us/sample - loss: 0.0066 - accuracy: 0.9987
Epoch 141/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0063 - accuracy: 0.9987
Epoch 142/200
796/796 [==============================] - 0s 39us/sample - loss: 0.0086 - accuracy: 0.9975
Epoch 143/200
796/796 [==============================] - 0s 42us/sample - loss: 0.0076 - accuracy: 0.9975
Epoch 144/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0057 - accuracy: 1.0000
Epoch 145/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0054 - accuracy: 1.0000
Epoch 146/200
796/796 [==============================] - 0s 35us/sample - loss: 0.0046 - accuracy: 1.0000
Epoch 147/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0050 - accuracy: 0.9987
Epoch 148/200
796/796 [==============================] - 0s 32us/sample - loss: 0.0059 - accuracy: 0.9987
Epoch 149/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0109 - accuracy: 0.9975
Epoch 150/200
796/796 [==============================] - 0s 41us/sample - loss: 0.0064 - accuracy: 0.9975
Epoch 151/200
796/796 [==============================] - 0s 32us/sample - loss: 0.0046 - accuracy: 0.9987
Epoch 152/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0042 - accuracy: 0.9987
Epoch 153/200
796/796 [==============================] - 0s 39us/sample - loss: 0.0072 - accuracy: 0.9975
Epoch 154/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0041 - accuracy: 0.9987
Epoch 155/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0040 - accuracy: 0.9987
Epoch 156/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0041 - accuracy: 0.9987
Epoch 157/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0030 - accuracy: 1.0000
Epoch 158/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0030 - accuracy: 1.0000
Epoch 159/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0033 - accuracy: 1.0000
Epoch 160/200
796/796 [==============================] - 0s 40us/sample - loss: 0.0038 - accuracy: 0.9987
Epoch 161/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0054 - accuracy: 0.9987
Epoch 162/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0070 - accuracy: 0.9987
Epoch 163/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0056 - accuracy: 0.9975
Epoch 164/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0041 - accuracy: 0.9987
Epoch 165/200
796/796 [==============================] - 0s 35us/sample - loss: 0.0038 - accuracy: 0.9987
Epoch 166/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0026 - accuracy: 1.0000
Epoch 167/200
796/796 [==============================] - 0s 35us/sample - loss: 0.0035 - accuracy: 0.9975
Epoch 168/200
796/796 [==============================] - 0s 32us/sample - loss: 0.0035 - accuracy: 0.9987
Epoch 169/200
796/796 [==============================] - 0s 39us/sample - loss: 0.0034 - accuracy: 0.9987
Epoch 170/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0026 - accuracy: 1.0000
Epoch 171/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0027 - accuracy: 0.9987
Epoch 172/200
796/796 [==============================] - 0s 40us/sample - loss: 0.0030 - accuracy: 0.9987
Epoch 173/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0038 - accuracy: 0.9975
Epoch 174/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0029 - accuracy: 0.9987
Epoch 175/200
796/796 [==============================] - 0s 46us/sample - loss: 0.0034 - accuracy: 0.9987
Epoch 176/200
796/796 [==============================] - 0s 40us/sample - loss: 0.0021 - accuracy: 0.9987
Epoch 177/200
796/796 [==============================] - 0s 41us/sample - loss: 0.0016 - accuracy: 1.0000
Epoch 178/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0018 - accuracy: 1.0000
Epoch 179/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0049 - accuracy: 0.9987
Epoch 180/200
796/796 [==============================] - 0s 39us/sample - loss: 0.0086 - accuracy: 0.9975
Epoch 181/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0410 - accuracy: 0.9849
Epoch 182/200
796/796 [==============================] - 0s 33us/sample - loss: 0.6212 - accuracy: 0.7827
Epoch 183/200
796/796 [==============================] - 0s 41us/sample - loss: 0.3779 - accuracy: 0.8291
Epoch 184/200
796/796 [==============================] - 0s 32us/sample - loss: 0.2979 - accuracy: 0.8693
Epoch 185/200
796/796 [==============================] - 0s 34us/sample - loss: 0.2046 - accuracy: 0.9146
Epoch 186/200
796/796 [==============================] - 0s 40us/sample - loss: 0.1602 - accuracy: 0.9347
Epoch 187/200
796/796 [==============================] - 0s 33us/sample - loss: 0.1186 - accuracy: 0.9447
Epoch 188/200
796/796 [==============================] - 0s 32us/sample - loss: 0.1010 - accuracy: 0.9611
Epoch 189/200
796/796 [==============================] - 0s 38us/sample - loss: 0.0628 - accuracy: 0.9749
Epoch 190/200
796/796 [==============================] - 0s 36us/sample - loss: 0.0585 - accuracy: 0.9812
Epoch 191/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0928 - accuracy: 0.9598
Epoch 192/200
796/796 [==============================] - 0s 36us/sample - loss: 0.1010 - accuracy: 0.9523
Epoch 193/200
796/796 [==============================] - 0s 39us/sample - loss: 0.0924 - accuracy: 0.9611
Epoch 194/200
796/796 [==============================] - 0s 32us/sample - loss: 0.1237 - accuracy: 0.9611
Epoch 195/200
796/796 [==============================] - 0s 32us/sample - loss: 0.1302 - accuracy: 0.9636
Epoch 196/200
796/796 [==============================] - 0s 41us/sample - loss: 0.1031 - accuracy: 0.9623
Epoch 197/200
796/796 [==============================] - 0s 33us/sample - loss: 0.0966 - accuracy: 0.9636
Epoch 198/200
796/796 [==============================] - 0s 34us/sample - loss: 0.0680 - accuracy: 0.9736
Epoch 199/200
796/796 [==============================] - 0s 40us/sample - loss: 0.0861 - accuracy: 0.9673
Epoch 200/200
796/796 [==============================] - 0s 34us/sample - loss: 0.1257 - accuracy: 0.9510
100/100 [==============================] - 0s 542us/sample - loss: 2.0607 - accuracy: 0.6800
Test accuracy: 0.68

Linear regression with BA probability prediction

In [51]:
from sklearn.linear_model import LinearRegression
import statsmodels.api as sm
from statsmodels.graphics.api import qqplot
In [52]:
df.columns
Out[52]:
Index(['ID', 'Name', 'Smiles', 'PercentF', 'Class', 'ROMol', 'QED', 'mol-sentence', 'mol2vec-000', 'mol2vec-001',
       ...
       'mol2vec-290', 'mol2vec-291', 'mol2vec-292', 'mol2vec-293', 'mol2vec-294', 'mol2vec-295', 'mol2vec-296', 'mol2vec-297', 'mol2vec-298', 'mol2vec-299'], dtype='object', length=308)
In [53]:
lm = LinearRegression()

cols_to_drop = ['ID', 'Name', 'Smiles', 'PercentF', 'Class', 'ROMol', 'QED', 'mol-sentence']
Y = df['PercentF'].astype('float32')
X = df.drop(cols_to_drop, axis=1)

lmfit = lm.fit(X, Y)
In [55]:
residuals = Y - lm.predict(X)
print("RSS : ", np.sum(residuals**2))
print("MSE : ", np.mean(residuals**2))
#Residuals are almost same as statsmodel.OLS
RSS :  544512.4
MSE :  547.2486
In [56]:
from sklearn.feature_selection import RFE  #Recursive(Backward) feature selection, takes all features and prunes out recursively.
rfe = RFE(estimator=lm, n_features_to_select=10, step=1)
rfe.fit(X, Y)
best_features = np.where(rfe.get_support())[0]
best_cols_names = [X.columns[i]  for i in best_features]
best_cols_names
Out[56]:
['mol2vec-070',
 'mol2vec-077',
 'mol2vec-086',
 'mol2vec-141',
 'mol2vec-162',
 'mol2vec-177',
 'mol2vec-206',
 'mol2vec-240',
 'mol2vec-248',
 'mol2vec-274']
In [57]:
from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import f_regression

skb = SelectKBest(f_regression, k=10)
skb.fit(X, Y)
best_features = np.where(skb.get_support())[0]
best_cols_names = [X.columns[i]  for i in best_features]
best_cols_names
Out[57]:
['mol2vec-055',
 'mol2vec-068',
 'mol2vec-083',
 'mol2vec-102',
 'mol2vec-156',
 'mol2vec-190',
 'mol2vec-212',
 'mol2vec-241',
 'mol2vec-251',
 'mol2vec-255']
In [58]:
X_c = sm.add_constant(X)
results = sm.OLS(Y,X_c).fit()
In [59]:
#Check if residuals are normally distributed
plt.xlabel('Residuals')
plt.ylabel('Frequency')
plt.xlim(-30,30)
plt.grid(True)
plt.hist(results.resid)
plt.show()
In [60]:
fig = sm.qqplot(results.resid, fit=True, line='r')
plt.show()
In [61]:
plt.plot(results.fittedvalues,  results.resid,'o')
plt.hlines(xmin=np.min(results.fittedvalues),xmax=np.max(results.fittedvalues),y=0)
plt.xlabel('Fitted Values')
plt.ylabel('Residuals')
plt.show()
In [62]:
from statsmodels.graphics.regressionplots import plot_leverage_resid2
fig = plot_leverage_resid2(results)
plt.show()
In [63]:
from statsmodels.graphics.regressionplots import plot_leverage_resid2
fig = plot_leverage_resid2(results)
plt.show()
In [64]:
# Cooks distance -  is also used estimate of the influence of a data point
influence = results.get_influence()
#c is the distance and p is p-value
(c, p) = influence.cooks_distance
plt.annotate("368",(368,c[368]))
plt.annotate("372",(372,c[372]))
plt.stem(np.arange(len(c)), c, markerfmt=",")
plt.show()

Test against a sample dataset from the customer

In [81]:
in_file = './sample-set.csv'
df_test = pd.read_csv(in_file, delimiter=',', usecols=[0, 1, 2, 3], names=['ID', 'Smiles', 'Name', 'PercentF'], header=None,encoding='latin-1')  # Assume <tab> separated
In [82]:
df_test['Name'] =  df_test['Name'].str.upper()
In [83]:
df_test
Out[83]:
ID Smiles Name PercentF
0 SS001 CC#CC(=O)N1CCC[C@H]1C1=NC(=C2N1C=CN=C2N)C1=CC=... CALQUENCE 25.0
1 SS002 OC(c1c(c2ccc1)nn(c3ccc([C@H]4CNCCC4)cc3)c2)=N;... ZEJULA 73.0
2 SS003 c1(ccc(cc1)C#N)NC(N[C@@H]1CCN([C@@H](c2[nH]c3c... DAURISMO 77.0
3 SS004 n1n(c(cc1C(F)(F)F)c1ccc(cc1)C)c1ccc(S(=O)(=O)N... CELEBREX 73.0
4 SS005 C1([C@@H]2[C@@H]3C=C[C@H]([C@@H]2C(N1NC(=O)c1c... TPOXX 100.0
5 SS006 CN(C)[C@H]1[C@@H]2C[C@H]3C(=C(O)[C@]2(O)C(=O)C... ACHROMYCIN V 77.0
6 SS007 C[C@]1([C@@](C(CO)=O)(O)CC2)[C@@H]2[C@H]3[C@@H... DELTA-CORTEF 75.0
7 SS008 CNCCCN1c2ccccc2CCc2ccccc12; CNCCCN1c2ccccc2CCc... NORPRAMIN 40.0
8 SS009 OC1N=C(c2ccccc2)c2cc(Cl)ccc2NC1=O; c1ccc(cc1)C... SERAX 92.8
9 SS010 C[C@@H]1[C@H]2[C@H](O)[C@H]3[C@H](N(C)C)C(O)=C... VIBRAMYCIN 93.0
10 SS011 ONC(O)=N; C(=O)(N)NO HYDREA 100.0
11 SS012 CC1Nc2cc(Cl)c(cc2C(=O)N1c1ccccc1C)S(N)(=O)=O; ... ZAROXOLYN 26.0
12 SS013 OC1N=C(c2ccccc2Cl)c2cc(Cl)ccc2NC1=O; c1ccc(c(c... ATIVAN 93.0
13 SS014 COc1cc(Cc2cnc(N)nc2N)cc(OC)c1OC; COc1cc(cc(c1O... PROLOPRIM 76.0
14 SS015 Cc1n(c2nn1)c(c3C(c4ccccc4)=NC2)ccc(Cl)c3; Cc1n... XANAX 88.0
15 SS016 Clc1cc(N2CCN(CCCN3C(=O)N(C4=N3)C=CC=C4)CC2)ccc... DESYREL 77.0
16 SS017 Cc1n(c2nn1)c(c3C(c4c(Cl)cccc4)=NC2)ccc(Cl)c3; ... HALCION 55.0
17 SS018 COc1c(C(O)=NCCc2ccc(S(NC(O)=NC3CCCCC3)(=O)=O)c... MICRONASE 95.0
18 SS019 Cc1ncc(C(O)=NCCc2ccc(S(NC(O)=NC3CCCCC3)(=O)=O)... GLUCOTROL 95.0
19 SS020 CCCC(O)=Nc1cc(C(C)=O)c(OCC(CNC(C)C)O)cc1; CCCC... SECTRAL 50.0
20 SS021 COc1c(OC)cc(c2c1)C(=N)NC(N3CCN(C(C4OCCC4)=O)CC... HYTRIN 82.0
21 SS022 COC(=O)C1=C(C)NC(C)=C(C1c1cccc(c1)N(=O)=O)C(=O... CARDENE 10.5
22 SS023 Nc1nc(cs1)C(=N\OCC(O)=O)\C(=O)N[C@H]1[C@H]2SCC... SUPRAX 47.0
23 SS024 CN1CCN(CC1)C1=Nc2cc(Cl)ccc2Nc2ccccc12; CN1CCN(... CLOZARIL 27.0
24 SS025 CC[C@H]1OC(=O)[C@H](C)[C@@H](O[C@H]2C[C@@](C)(... BIAXIN 55.0
25 SS026 CC[C@H]1OC(=O)[C@H](C)[C@@H](O[C@H]2C[C@@](C)(... ZITHROMAX 34.0
26 SS027 OC(=O)C1CCn2c1ccc2C(=O)c1ccccc1; c1ccc(cc1)C(=... TORADOL 100.0
27 SS028 O[C@@H]1[C@H]([C@@H]2OC1)OC[C@H]2ON(=O)=O; C1[... ISMO 93.0
28 SS029 CCC(C)n1ncn(-c2ccc(cc2)N2CCN(CC2)c2ccc(OCC3COC... SPORANOX 55.0
29 SS030 CO[C@H]1\C=C\O[C@@]2(C)Oc3c(C)c(O)c4C(=O)C(NC(... MYCOBUTIN 20.0
30 SS031 C1([C@@]2(O[C@@H]([C@H](C[C@H]2C)OC)[C@H](C[C@... PROGRAF 15.0
31 SS032 CN(C)C(=N)NC(N)=N; CN(C)C(=N)NC(=N)N GLUCOPHAGE 52.0
32 SS033 COc1cccc(c1)[C@@]1(O)CCCC[C@@H]1CN(C)C; CN(C)C... ULTRAM 70.0
33 SS034 CNS(=O)(=O)Cc1ccc2[nH]cc(CCN(C)C)c2c1; CNS(=O)... IMITREX 14.0
34 SS035 COc1ccccc1OCCNCC(O)COc1cccc2[nH]c3ccccc3c12; C... COREG 25.0
35 SS036 CC(C)(C)NC(=O)[C@@H]1C[C@@H]2CCCC[C@@H]2CN1C[C... INVIRASE 7.0
36 SS037 OC(=O)CC1(CC1)CS[C@H](CCc1ccccc1C(C)(C)O)c1ccc... SINGULAIR 62.0
37 SS038 CN(C)CCCC1(OCc2cc(ccc12)C#N)c1ccc(F)cc1; CN(C)... CELEXA 80.0
38 SS039 CC(=O)NC[C@H]1CN(C(=O)O1)c1ccc(N2CCOCC2)c(F)c1... ZYVOX 100.0
39 SS040 CC([C@@H]1CC[C@@H](C(O)=N[C@@H](C(O)=O)Cc2cccc... STARLIX 72.0
40 SS041 CNCC[C@@H](Oc1ccccc1C)c1ccccc1; Cc1ccccc1O[C@H... STRATTERA 63.0
41 SS042 CC(C[C@@H](CC(O)=O)CN)C; CC(C)C[C@@H](CC(=O)O)CN LYRICA 90.0
42 SS043 C[C@@H]1c(c2CCNC1)cc(Cl)cc2; C[C@H]1CNCCc2c1cc... BELVIQ 0.8
43 SS044 CN(C([C@@H]1C[C@@H](NC(c2sc(c3n2)CN(C)CC3)=O)[... SAVAYSA 62.0
44 SS045 CC(C(C(=O)N(C1CCCC1)c(c23)nc(Nc4ncc(N5CCNCC5)c... IBRANCE 46.0
45 SS046 COc1c(OC)cc(c2c1)[C@@H](CN(CCCN3C(=O)Cc(c4CC3)... CORLANOR 40.0
46 SS047 COc1c(C(O)=O)cc(CN(C([C@H](Cc2c(C)cc(C(N)=O)cc... VIBERZI 1.1
47 SS048 CC(C[C@@H](B(O)O)N=C(CN=C(c1c(Cl)ccc(Cl)c1)O)O)C NINLARO 58.0
48 SS049 CC(COc1ccc(CN=C(N(C2CCN(C)CC2)Cc3ccc(F)cc3)O)c... NUPLAZID 53.0
49 SS050 COc1c(OC)cc(c2c1)[C@H](N3CC2)C[C@@H](OC([C@H](... INGREZZA 49.0
50 SS051 COC1=CC=C(NC(=O)C2=CC=C(C=C2)C(=N)N(C)C)C(=C1)... BEVYXXA 34.0
51 SS052 O.[Fe+3].OC(CC(C(O)=O)(CC(O)=O)O)=O FERRIC CITRATE 0.7
52 SS053 COCCOC(=O)C1=C(C)NC(C)=C(C1c1cccc(c1)N(=O)=O)C... NIMODIPINE 12.0
53 SS054 C[C@H]1[C@@H](N(c2c(c3ncn2)cc[nH]3)C)CN(C(CC#N... XELJANZ 74.0
54 SS055 COc1ccc2nccc([C@H](O)[C@H]3C[C@@H]4CCN3C[C@@H]... QUINIDINE 75.0
55 SS056 C[C@H]([C@](c1c(F)cc(F)cc1)(Cn2ncnc2)O)c3c(F)c... VORICONAZOLE 96.0
In [84]:
in_file = './sample-smiles.csv'
df_perc = pd.read_csv(in_file, delimiter=',', usecols=[0, 1, 2, 3], names=['ID', 'Name_alt', 'Name', 'Smiles'], header=None,encoding='latin-1')  # Assume <tab> separated
In [85]:
df_perc
Out[85]:
ID Name_alt Name Smiles
0 1216 TRIMETHOPRIM PROLOPRIM COc1cc(Cc2cnc(N)nc2N)cc(OC)c1OC
1 2261 CLOZAPINE CLOZARIL CN1CCN(C2=Nc3cc(Cl)ccc3Nc3ccccc32)CC1
2 4730 HYDROXYUREA HYDREA NC(=O)NO
3 5582 GLYBURIDE MICRONASE COc1ccc(Cl)cc1C(=O)NCCc1ccc(S(=O)(=O)NC(=O)NC2...
4 12428 AZITHROMYCIN ZITHROMAX CC[C@H]1OC(=O)[C@H](C)[C@@H](O[C@H]2C[C@@](C)(...
5 16506 OXAZEPAM SERAX O=C1Nc2ccc(Cl)cc2C(c2ccccc2)=NC1O
6 16959 LORAZEPAM ATIVAN O=C1Nc2ccc(Cl)cc2C(c2ccccc2Cl)=NC1O
7 18694 CELECOXIB CELEBREX Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2...
8 24418 SAQUINAVIR MESYLATE INVIRASE CC(C)(C)NC(=O)[C@@H]1C[C@@H]2CCCC[C@@H]2CN1C[C...
9 26280 LINEZOLID ZYVOX CC(=O)NC[C@H]1CN(c2ccc(N3CCOCC3)c(F)c2)C(=O)O1
10 27111 SUMATRIPTAN IMITREX CNS(=O)(=O)Cc1ccc2[nH]cc(CCN(C)C)c2c1
11 27308 VORICONAZOLE VORICONAZOLE C[C@@H](c1ncncc1F)[C@](O)(Cn1cncn1)c1ccc(F)cc1F
12 27370 PREDNISOLONE DELTA-CORTEF C[C@]12C[C@H](O)[C@H]3[C@@H](CCC4=CC(=O)C=C[C@...
13 27419 TRIAZOLAM HALCION Cc1nnc2n1-c1ccc(Cl)cc1C(c1ccccc1Cl)=NC2
14 27648 ALPRAZOLAM XANAX Cc1nnc2n1-c1ccc(Cl)cc1C(c1ccccc1)=NC2
15 36662 CARVEDILOL COREG COc1ccccc1OCCNCC(O)COc1cccc2[nH]c3ccccc3c12
16 46727 NATEGLINIDE STARLIX CC(C)[C@H]1CC[C@H](C(=O)N[C@H](Cc2ccccc2)C(=O)...
17 73548 METOLAZONE ZAROXOLYN Cc1ccccc1N1C(=O)c2cc(S(N)(=O)=O)c(Cl)cc2NC1C
18 97867 ITRACONAZOLE SPORANOX CCC(C)n1ncn(-c2ccc(N3CCN(c4ccc(OCC5COC(Cn6cncn...
19 136161 PREGABALIN LYRICA CC(C)C[C@H](CN)CC(=O)O
20 139286 GLIPIZIDE GLUCOTROL Cc1cnc(C(=O)NCCc2ccc(S(=O)(=O)NC(=O)NC3CCCCC3)...
21 263453 ISOSORBIDE MONONITRATE ISMO O=[N+]([O-])O[C@@H]1CO[C@@H]2[C@@H](O)CO[C@H]12
22 319858 PALBOCICLIB IBRANCE CC(=O)c1c(C)c2cnc(Nc3ccc(N4CCNCC4)cn3)nc2n(C2C...
23 365248 NIMODIPINE NIMODIPINE COCCOC(=O)C1=C(C)NC(C)=C(C(=O)OC(C)C)C1c1cccc(...
24 397074 RIFABUTIN MYCOBUTIN CO[C@H]1/C=C/O[C@@]2(C)Oc3c(C)c(O)c4c(c3C2=O)C...
25 430642 CEFIXIME SUPRAX C=CC1=C(C(=O)O)N2C(=O)[C@@H](NC(=O)/C(=N\OCC(=...
26 455284 TETRACYCLINE HYDROCHLORIDE ACHROMYCIN V CN(C)[C@@H]1C(O)=C(C(N)=O)C(=O)[C@@]2(O)C(O)=C...
27 459679 BETRIXABAN BEVYXXA COc1ccc(NC(=O)c2ccc(C(=N)N(C)C)cc2)c(C(=O)Nc2c...
28 545075 DESIPRAMINE HYDROCHLORIDE NORPRAMIN CNCCCN1c2ccccc2CCc2ccccc21.Cl
29 547576 ATOMOXETINE HYDROCHLORIDE STRATTERA CNCC[C@@H](Oc1ccccc1C)c1ccccc1.Cl
30 547681 METFORMIN HYDROCHLORIDE GLUCOPHAGE CN(C)C(=N)N=C(N)N.Cl
31 570147 CLARITHROMYCIN BIAXIN CC[C@H]1OC(=O)[C@H](C)[C@@H](O[C@H]2C[C@@](C)(...
32 674277 NICARDIPINE HYDROCHLORIDE CARDENE COC(=O)C1=C(C)NC(C)=C(C(=O)OCCN(C)Cc2ccccc2)C1...
33 674632 MONTELUKAST SODIUM SINGULAIR CC(C)(O)c1ccccc1CC[C@@H](SCC1(CC(=O)[O-])CC1)c...
34 674650 DOXYCYCLINE VIBRAMYCIN C[C@H]1c2cccc(O)c2C(=O)C2=C(O)[C@]3(O)C(=O)C(C...
35 674732 CITALOPRAM HYDROBROMIDE CELEXA Br.CN(C)CCCC1(c2ccc(F)cc2)OCc2cc(C#N)ccc21
36 674749 TRAZODONE HYDROCHLORIDE DESYREL Cl.O=c1n(CCCN2CCN(c3cccc(Cl)c3)CC2)nc2ccccn12
37 674764 ACEBUTOLOL HYDROCHLORIDE SECTRAL CCCC(=O)Nc1ccc(OCC(O)CNC(C)C)c(C(C)=O)c1.Cl
38 675075 KETOROLAC TROMETHAMINE TORADOL NC(CO)(CO)CO.O=C(c1ccccc1)c1ccc2n1CCC2C(=O)O
39 675101 SUMATRIPTAN SUCCINATE IMITREX CNS(=O)(=O)Cc1ccc2[nH]cc(CCN(C)C)c2c1.O=C(O)CC...
40 699416 TRAMADOL HYDROCHLORIDE ULTRAM COc1cccc(C2(O)CCCCC2CN(C)C)c1.Cl
41 706068 TECOVIRIMAT TPOXX O=C(NN1C(=O)[C@H]2[C@H]3C=C[C@H]([C@@H]4C[C@H]...
42 1369607 LORCASERIN HYDROCHLORIDE BELVIQ C[C@H]1CNCCc2ccc(Cl)cc21.Cl
43 1376013 TOFACITINIB CITRATE XELJANZ C[C@@H]1CCN(C(=O)CC#N)C[C@@H]1N(C)c1ncnc2[nH]c...
44 1377952 EDOXABAN TOSYLATE SAVAYSA CN1CCc2nc(C(=O)N[C@@H]3C[C@@H](C(=O)N(C)C)CC[C...
45 1417028 IVABRADINE HYDROCHLORIDE CORLANOR COc1cc2c(cc1OC)CC(=O)N(CCCN(C)C[C@H]1Cc3cc(OC)...
46 1425665 ELUXADOLINE VIBERZI COc1ccc(CN(C(=O)[C@@H](N)Cc2c(C)cc(C(N)=O)cc2C...
47 1540465 DOXYCYCLINE CALCIUM VIBRAMYCIN C[C@H]1c2cccc([O-])c2C(=O)C2=C(O)[C@]3(O)C(=O)...
48 1590761 PIMAVANSERIN TARTRATE NUPLAZID CC(C)COc1ccc(CNC(=O)N(Cc2ccc(F)cc2)C2CCN(C)CC2...
49 1927281 IXAZOMIB CITRATE NINLARO CC(C)C[C@H](NC(=O)CNC(=O)c1cc(Cl)ccc1Cl)B1OC(=...
50 2039219 VALBENAZINE TOSYLATE INGREZZA COc1cc2c(cc1OC)[C@H]1C[C@@H](OC(=O)[C@@H](N)C(...
51 2039319 ACALABRUTINIB CALQUENCE CC#CC(=O)N1CCC[C@H]1c1nc(-c2ccc(C(=O)Nc3ccccn3...
52 2197433 TERAZOSIN HYDROCHLORIDE HYTRIN COc1cc2nc(N3CCN(C(=O)C4CCCO4)CC3)nc(N)c2cc1OC....
53 2197611 DOXYCYCLINE HYCLATE VIBRAMYCIN CCO.C[C@H]1c2cccc(O)c2C(=O)C2=C(O)[C@]3(O)C(=O...
54 2197758 TACROLIMUS PROGRAF C=CC[C@@H]1/C=C(\C)C[C@H](C)C[C@H](OC)[C@H]2O[...
55 2197793 NIRAPARIB TOSYLATE ZEJULA Cc1ccc(S(=O)(=O)O)cc1.NC(=O)c1cccc2cn(-c3ccc([...
56 2335505 GLASDEGIB MALEATE DAURISMO CN1CC[C@@H](NC(=O)Nc2ccc(C#N)cc2)C[C@@H]1c1nc2...
In [89]:
df_final = df_test.merge(df_perc, on='Name', how='inner', suffixes=('_1', '_2'))[['Name', 'PercentF', 'ID_2', 'Smiles_2']]
df_final = df_final.rename(columns={'ID_2': 'ID', 'Smiles_2': 'Smiles'})
df_final
Out[89]:
Name PercentF ID Smiles
0 CALQUENCE 25.0 2039319 CC#CC(=O)N1CCC[C@H]1c1nc(-c2ccc(C(=O)Nc3ccccn3...
1 ZEJULA 73.0 2197793 Cc1ccc(S(=O)(=O)O)cc1.NC(=O)c1cccc2cn(-c3ccc([...
2 DAURISMO 77.0 2335505 CN1CC[C@@H](NC(=O)Nc2ccc(C#N)cc2)C[C@@H]1c1nc2...
3 CELEBREX 73.0 18694 Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2...
4 TPOXX 100.0 706068 O=C(NN1C(=O)[C@H]2[C@H]3C=C[C@H]([C@@H]4C[C@H]...
5 ACHROMYCIN V 77.0 455284 CN(C)[C@@H]1C(O)=C(C(N)=O)C(=O)[C@@]2(O)C(O)=C...
6 DELTA-CORTEF 75.0 27370 C[C@]12C[C@H](O)[C@H]3[C@@H](CCC4=CC(=O)C=C[C@...
7 NORPRAMIN 40.0 545075 CNCCCN1c2ccccc2CCc2ccccc21.Cl
8 SERAX 92.8 16506 O=C1Nc2ccc(Cl)cc2C(c2ccccc2)=NC1O
9 VIBRAMYCIN 93.0 674650 C[C@H]1c2cccc(O)c2C(=O)C2=C(O)[C@]3(O)C(=O)C(C...
10 VIBRAMYCIN 93.0 1540465 C[C@H]1c2cccc([O-])c2C(=O)C2=C(O)[C@]3(O)C(=O)...
11 VIBRAMYCIN 93.0 2197611 CCO.C[C@H]1c2cccc(O)c2C(=O)C2=C(O)[C@]3(O)C(=O...
12 HYDREA 100.0 4730 NC(=O)NO
13 ZAROXOLYN 26.0 73548 Cc1ccccc1N1C(=O)c2cc(S(N)(=O)=O)c(Cl)cc2NC1C
14 ATIVAN 93.0 16959 O=C1Nc2ccc(Cl)cc2C(c2ccccc2Cl)=NC1O
15 PROLOPRIM 76.0 1216 COc1cc(Cc2cnc(N)nc2N)cc(OC)c1OC
16 XANAX 88.0 27648 Cc1nnc2n1-c1ccc(Cl)cc1C(c1ccccc1)=NC2
17 DESYREL 77.0 674749 Cl.O=c1n(CCCN2CCN(c3cccc(Cl)c3)CC2)nc2ccccn12
18 HALCION 55.0 27419 Cc1nnc2n1-c1ccc(Cl)cc1C(c1ccccc1Cl)=NC2
19 MICRONASE 95.0 5582 COc1ccc(Cl)cc1C(=O)NCCc1ccc(S(=O)(=O)NC(=O)NC2...
20 GLUCOTROL 95.0 139286 Cc1cnc(C(=O)NCCc2ccc(S(=O)(=O)NC(=O)NC3CCCCC3)...
21 SECTRAL 50.0 674764 CCCC(=O)Nc1ccc(OCC(O)CNC(C)C)c(C(C)=O)c1.Cl
22 HYTRIN 82.0 2197433 COc1cc2nc(N3CCN(C(=O)C4CCCO4)CC3)nc(N)c2cc1OC....
23 CARDENE 10.5 674277 COC(=O)C1=C(C)NC(C)=C(C(=O)OCCN(C)Cc2ccccc2)C1...
24 SUPRAX 47.0 430642 C=CC1=C(C(=O)O)N2C(=O)[C@@H](NC(=O)/C(=N\OCC(=...
25 CLOZARIL 27.0 2261 CN1CCN(C2=Nc3cc(Cl)ccc3Nc3ccccc32)CC1
26 BIAXIN 55.0 570147 CC[C@H]1OC(=O)[C@H](C)[C@@H](O[C@H]2C[C@@](C)(...
27 ZITHROMAX 34.0 12428 CC[C@H]1OC(=O)[C@H](C)[C@@H](O[C@H]2C[C@@](C)(...
28 TORADOL 100.0 675075 NC(CO)(CO)CO.O=C(c1ccccc1)c1ccc2n1CCC2C(=O)O
29 ISMO 93.0 263453 O=[N+]([O-])O[C@@H]1CO[C@@H]2[C@@H](O)CO[C@H]12
30 SPORANOX 55.0 97867 CCC(C)n1ncn(-c2ccc(N3CCN(c4ccc(OCC5COC(Cn6cncn...
31 MYCOBUTIN 20.0 397074 CO[C@H]1/C=C/O[C@@]2(C)Oc3c(C)c(O)c4c(c3C2=O)C...
32 PROGRAF 15.0 2197758 C=CC[C@@H]1/C=C(\C)C[C@H](C)C[C@H](OC)[C@H]2O[...
33 GLUCOPHAGE 52.0 547681 CN(C)C(=N)N=C(N)N.Cl
34 ULTRAM 70.0 699416 COc1cccc(C2(O)CCCCC2CN(C)C)c1.Cl
35 IMITREX 14.0 27111 CNS(=O)(=O)Cc1ccc2[nH]cc(CCN(C)C)c2c1
36 IMITREX 14.0 675101 CNS(=O)(=O)Cc1ccc2[nH]cc(CCN(C)C)c2c1.O=C(O)CC...
37 COREG 25.0 36662 COc1ccccc1OCCNCC(O)COc1cccc2[nH]c3ccccc3c12
38 INVIRASE 7.0 24418 CC(C)(C)NC(=O)[C@@H]1C[C@@H]2CCCC[C@@H]2CN1C[C...
39 SINGULAIR 62.0 674632 CC(C)(O)c1ccccc1CC[C@@H](SCC1(CC(=O)[O-])CC1)c...
40 CELEXA 80.0 674732 Br.CN(C)CCCC1(c2ccc(F)cc2)OCc2cc(C#N)ccc21
41 ZYVOX 100.0 26280 CC(=O)NC[C@H]1CN(c2ccc(N3CCOCC3)c(F)c2)C(=O)O1
42 STARLIX 72.0 46727 CC(C)[C@H]1CC[C@H](C(=O)N[C@H](Cc2ccccc2)C(=O)...
43 STRATTERA 63.0 547576 CNCC[C@@H](Oc1ccccc1C)c1ccccc1.Cl
44 LYRICA 90.0 136161 CC(C)C[C@H](CN)CC(=O)O
45 BELVIQ 0.8 1369607 C[C@H]1CNCCc2ccc(Cl)cc21.Cl
46 SAVAYSA 62.0 1377952 CN1CCc2nc(C(=O)N[C@@H]3C[C@@H](C(=O)N(C)C)CC[C...
47 IBRANCE 46.0 319858 CC(=O)c1c(C)c2cnc(Nc3ccc(N4CCNCC4)cn3)nc2n(C2C...
48 CORLANOR 40.0 1417028 COc1cc2c(cc1OC)CC(=O)N(CCCN(C)C[C@H]1Cc3cc(OC)...
49 VIBERZI 1.1 1425665 COc1ccc(CN(C(=O)[C@@H](N)Cc2c(C)cc(C(N)=O)cc2C...
50 NINLARO 58.0 1927281 CC(C)C[C@H](NC(=O)CNC(=O)c1cc(Cl)ccc1Cl)B1OC(=...
51 NUPLAZID 53.0 1590761 CC(C)COc1ccc(CNC(=O)N(Cc2ccc(F)cc2)C2CCN(C)CC2...
52 INGREZZA 49.0 2039219 COc1cc2c(cc1OC)[C@H]1C[C@@H](OC(=O)[C@@H](N)C(...
53 BEVYXXA 34.0 459679 COc1ccc(NC(=O)c2ccc(C(=N)N(C)C)cc2)c(C(=O)Nc2c...
54 NIMODIPINE 12.0 365248 COCCOC(=O)C1=C(C)NC(C)=C(C(=O)OC(C)C)C1c1cccc(...
55 XELJANZ 74.0 1376013 C[C@@H]1CCN(C(=O)CC#N)C[C@@H]1N(C)c1ncnc2[nH]c...
56 VORICONAZOLE 96.0 27308 C[C@@H](c1ncncc1F)[C@](O)(Cn1cncn1)c1ccc(F)cc1F
In [90]:
model_path = './models/model_300dim.pkl'
out_file = 'BA-test-vectors.csv'
X_test, df_test = featurize(df_final, out_file, model_path, 2, uncommon='UNK')
Loading molecules.
Keeping only molecules that can be processed by RDKit.
Featurizing molecules.
(57, 300)
In [91]:
cols_to_drop = ['ID', 'Name', 'Smiles', 'PercentF', 'ROMol', 'QED', 'mol-sentence']
Y_test = df_test['PercentF'].astype('float32')
X_test = df_test.drop(cols_to_drop, axis=1)

residuals = Y_test - lm.predict(X_test)
In [92]:
print("RSS : ", np.sum(residuals**2))
print("MSE : ", np.mean(residuals**2))
RSS :  48218.46
MSE :  845.9379
In [93]:
plt.plot(lm.predict(X_test),  residuals,'o')
plt.hlines(xmin=np.min(lm.predict(X_test)),xmax=np.max(lm.predict(X_test)),y=0)
plt.xlabel('Fitted Values')
plt.ylabel('Residuals')
plt.show()
In [94]:
lm.predict(X_test)
Out[94]:
array([ 47.41197  ,  66.26637  ,   6.764305 , 105.81418  ,  88.62217  ,
        49.39244  ,  63.082233 ,  53.155247 ,  87.77145  ,  54.861435 ,
        40.871964 ,  36.397366 ,  70.155304 ,  61.980556 ,  81.05878  ,
        67.36632  ,  72.08817  ,  72.14193  ,  65.37557  ,  75.11885  ,
        97.5473   ,  45.92444  ,  37.614655 ,  11.953121 ,  39.44533  ,
        43.75448  ,  24.877884 ,  24.43826  ,  38.673847 ,  69.57985  ,
        77.736336 ,  33.393166 ,  23.733795 ,  38.540592 ,  45.582966 ,
        41.182705 ,  42.229855 ,  47.663445 ,  47.437477 ,  44.54653  ,
        97.40156  ,  86.12162  ,  48.451756 ,  68.08896  ,  82.015785 ,
        70.592575 ,  41.89241  ,  79.81784  ,  45.395706 ,  70.6928   ,
        90.33299  ,   7.3096695,  38.395206 ,  44.177162 ,  14.907818 ,
        96.96376  , 115.36467  ], dtype=float32)
In [97]:
preds = lm.predict(X_test)
index = range(len(preds))

plt.plot(index, preds, 'o')
plt.plot(index, Y_test,'x')
plt.vlines(ymin=np.min(preds),ymax=np.max(preds),x=0)
plt.hlines(xmin=np.min(index),xmax=np.max(index),y=0)
plt.xlabel('Example')
plt.ylabel('True Values:x, Prediction:o ')
plt.show()